Traktor/myenv/Lib/site-packages/torchaudio/models/wav2vec2/components.py

1168 lines
46 KiB
Python
Raw Normal View History

2024-05-26 05:12:46 +02:00
import logging
from typing import List, Optional, Tuple
import torch
from torch import nn, Tensor
from torch.nn import Module, Parameter
from .wavlm_attention import WavLMSelfAttention
_LG = logging.getLogger(__name__)
def _init_transformer_params(module):
    """
    Initialize the parameters of a Transformer sub-module as done for Wav2Vec2/HuBERT.

    * ``nn.Linear``: the weight is redrawn from a normal distribution with mean 0 and
      standard deviation 0.02; the bias, when present, is set to 0.
    * ``nn.Embedding``: the weight is redrawn from the same normal distribution; the row at
      ``padding_idx``, when one is configured, is set to 0.

    Modules of any other type are left untouched.

    Note:
        This method corresponds to
        `init_bert_params
        <https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_sentence_encoder.py#L21>`__
        in the original ``fairseq`` implementation.
    """

    def _fill_normal(tensor):
        # The sample is drawn on a CPU tensor and copied back to the original device
        # (presumably so the draw does not depend on the device's RNG -- TODO confirm intent).
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        has_bias = module.bias is not None
        if has_bias:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()
class LayerNorm(nn.LayerNorm):
    """Layer norm applied over the second-to-last dimension.

    The last two dimensions are swapped, ``nn.functional.layer_norm`` normalizes over the
    (now trailing) ``normalized_shape``, and the dimensions are swapped back so the output
    has the same layout as the input.
    """

    def forward(self, input: Tensor) -> Tensor:
        swapped = input.transpose(-2, -1)
        normalized = nn.functional.layer_norm(swapped, self.normalized_shape, self.weight, self.bias, self.eps)
        return normalized.transpose(-2, -1)
class ConvLayerBlock(Module):
    """Convolution unit of FeatureExtractor.

    Applies ``nn.Conv1d`` (no padding), then the optional normalization module, then GELU,
    and updates the per-sample valid lengths accordingly.

    Args:
        in_channels (int): Number of input channels of the convolution.
        out_channels (int): Number of output channels of the convolution.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        bias (bool): Whether the convolution has a bias term.
        layer_norm (Module or None): Normalization applied to the convolution output, or ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[Module],
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Registered before ``conv`` -- keeps the original submodule/parameter order.
        self.layer_norm = layer_norm
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Valid length of each sample. Shape ``[batch, ]``.

        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Valid length of each output sample, shape ``[batch, ]``
                (``None`` when ``length`` is ``None``).
        """
        out = self.conv(x)
        if self.layer_norm is not None:
            out = self.layer_norm(out)
        out = nn.functional.gelu(out)

        if length is not None:
            # Output size of an unpadded convolution: floor((length - kernel_size) / stride) + 1.
            length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
            # Inputs shorter than the kernel (e.g. length 0) give a negative count; clamp at 0.
            length = torch.max(torch.zeros_like(length), length)
        return out, length
class FeatureExtractor(Module):
    """Extract features from audio by running a stack of convolution blocks.

    Args:
        conv_layers (nn.ModuleList):
            Convolution layers. Each one is called as ``layer(x, length)`` and must return
            the pair ``(x, length)``.
    """

    def __init__(
        self,
        conv_layers: nn.ModuleList,
    ):
        super().__init__()
        self.conv_layers = conv_layers

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.

        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.

        Raises:
            ValueError: If ``x`` is not 2D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")

        # Add a singleton channel axis: (batch, channel==1, frame).
        features = x.unsqueeze(1)
        for conv_block in self.conv_layers:
            features, length = conv_block(features, length)  # (batch, feature, frame)
        # Move features last: (batch, frame, feature).
        return features.transpose(1, 2), length
class FeatureProjection(Module):
    """Layer that connects FeatureExtractor and Encoder.

    Normalizes the extracted features, projects them to the encoder dimension with a
    linear layer, then applies dropout.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_features)
        self.projection = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``

        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        normalized = self.layer_norm(x)
        projected = self.projection(normalized)
        return self.dropout(projected)
class ConvolutionalPositionalEmbedding(Module):
"""Positional embedding which is placed at the beginning of Transformer.
Args:
embed_dim (int): Feature dimension of the input Tensor.
kernel_size (int): The number of frames to be use.
groups (int): The number of groups in feature dimensions.
"""
def __init__(
self,
embed_dim: int,
kernel_size: int,
groups: int,
):
super().__init__()
self.embed_dim = embed_dim
self.kernel_size = kernel_size
self.conv = nn.Conv1d(
in_channels=embed_dim,
out_channels=embed_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
def __prepare_scriptable__(self):
if self.conv.__class__.__name__ == "ParametrizedConv1d":
_LG.warning("Removing weight_norm from %s", self.__class__.__name__)
torch.nn.utils.parametrize.remove_parametrizations(self.conv, "weight")
return self
def forward(self, x):
"""
Args:
x (Tensor): shape ``[batch, frame, feature]``.
Returns:
Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
"""
x = x.transpose(-2, -1)
x = self.conv(x)
if self.num_remove > 0:
x = x[..., : -self.num_remove]
x = torch.nn.functional.gelu(x)
x = x.transpose(-2, -1)
return x
class SelfAttention(Module):
    """Multihead Self Attention module

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        dropout (float, optional):
            Dropout probability on attn_output_weights. Default: ``0.0``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        head_dim = embed_dim // num_heads
        if head_dim * num_heads != embed_dim:
            raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim

        # 1/sqrt(head_dim); equal to the default scale used by scaled_dot_product_attention.
        self.scaling = self.head_dim**-0.5

        # Creation order (k, v, q, out) is kept so parameter initialization is unchanged.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
            attention_mask (Tensor or ``None``, optional):
                shape: ``[batch_size, 1, sequence_length, sequence_length]``; passed to
                ``scaled_dot_product_attention`` as ``attn_mask``.
            position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
            key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
                :py:class:`WavLMSelfAttention`.

        Returns:
            (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
                with :py:class:`WavLMSelfAttention`).
                Attention output shape: ``[batch, sequence_length, embed_dim]``.

        Raises:
            ValueError: If ``x`` or ``attention_mask`` has an unexpected shape.
        """
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
            )
        batch_size, length, _ = x.size()
        if attention_mask is not None:
            expected_mask_shape = (batch_size, 1, length, length)
            if attention_mask.size() != expected_mask_shape:
                raise ValueError(
                    f"The expected attention mask shape is {expected_mask_shape}. " f"Found {attention_mask.size()}."
                )

        # Project and split into heads: (B, L, E) -> (B, L, nH, Hd) -> (B, nH, L, Hd).
        per_head = (batch_size, length, self.num_heads, self.head_dim)
        query = self.q_proj(x).view(*per_head).transpose(2, 1)
        key = self.k_proj(x).view(*per_head).transpose(2, 1)
        value = self.v_proj(x).view(*per_head).transpose(2, 1)

        # Attention dropout is only active in training mode.
        dropout_p = self.dropout if self.training else 0.0
        context = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
        )

        # Merge heads back: (B, nH, L, Hd) -> (B, L, nH * Hd).
        context = context.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        return self.out_proj(context), None  # None keeps the WavLMSelfAttention return signature
class FeedForward(Module):
    """Layer that follows attention layer in encoder layer.

    Linear expansion -> GELU -> dropout -> linear contraction -> dropout.

    Args:
        io_features (int): Input and output feature dimension.
        intermediate_features (int): Hidden dimension between the two linear layers.
        intermediate_dropout (float): Dropout probability after the activation.
        output_dropout (float): Dropout probability on the output.
    """

    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
    ):
        super().__init__()
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`

        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        hidden = torch.nn.functional.gelu(self.intermediate_dense(x))
        hidden = self.intermediate_dropout(hidden)
        return self.output_dropout(self.output_dense(hidden))
class EncoderLayer(Module):
    """A layer unit in encoder. Combines multihead self attention and feed forward.

    Args:
        attention (Module): Self-attention module. It must expose ``embed_dim`` and return a
            ``(output, position_bias)`` pair.
        dropout (float): Dropout probability applied to the attention output.
        layer_norm_first (bool): If True, normalize before the attention and feed-forward
            sub-layers (pre-norm); otherwise normalize after the residual additions (post-norm).
        feed_forward (Module): Feed-forward sub-layer.
    """

    def __init__(
        self,
        attention: Module,
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Module,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(attention.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(attention.embed_dim)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
            attention_mask (Tensor or ``None``, optional): attention mask
                of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
            position_bias (Tensor or ``None``, optional): position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``.
                Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
            key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
                Only used for WavLM model, ignored otherwise. (Default: ``None``)

        Returns:
            (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WavLM model,
                ``None`` otherwise.
        """
        residual = x
        if self.layer_norm_first:
            x = self.layer_norm(x)

        attn_out, position_bias = self.attention(
            x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask
        )
        x = residual + self.dropout(attn_out)

        if self.layer_norm_first:
            # Pre-norm: only the feed-forward input is normalized; the residual path is not.
            out = x + self.feed_forward(self.final_layer_norm(x))
        else:
            # Post-norm: normalize after the attention residual and after the feed-forward residual.
            x = self.layer_norm(x)
            out = self.final_layer_norm(x + self.feed_forward(x))
        return out, position_bias
class Transformer(Module):
    """Stack of encoder layers preceded by a convolutional positional embedding.

    Args:
        pos_conv_embed (Module): Positional embedding added to the input; must expose ``embed_dim``.
        dropout (float): Dropout probability applied after the positional embedding.
        layers (Module): Iterable of encoder layers, each called as
            ``layer(x, attention_mask, position_bias=...)`` and returning ``(x, position_bias)``.
        layer_norm_first (bool): If True, ``layer_norm`` is applied before the encoder layers;
            otherwise after them.
        layer_drop (float): Probability of skipping each layer in training mode.
    """

    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        # Add positional information, optionally normalize, then apply dropout.
        embedded = x + self.pos_conv_embed(x)
        if self.layer_norm_first:
            embedded = self.layer_norm(embedded)
        return self.dropout(embedded)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tensor:
        x = self._preprocess(x)
        for layer in self.layers:
            # LayerDrop: the random draw only happens in training mode (short-circuit).
            skip_layer = self.training and torch.rand(1).item() <= self.layer_drop
            if not skip_layer:
                x, position_bias = layer(x, attention_mask, position_bias=position_bias)

        if not self.layer_norm_first:
            x = self.layer_norm(x)
        return x

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        """Return the output of each encoder layer, stopping after ``num_layers`` layers if given.

        Layer drop is not applied, and neither is the trailing ``layer_norm``.

        Raises:
            ValueError: If ``num_layers`` is outside ``[1, len(self.layers)]``.
        """
        if num_layers is not None:
            if not 0 < num_layers <= len(self.layers):
                raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")

        outputs: List[Tensor] = []
        position_bias = None
        x = self._preprocess(x)
        for layer in self.layers:
            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
            outputs.append(x)
            if num_layers is not None and len(outputs) >= num_layers:
                return outputs
        return outputs
class Encoder(Module):
    """Feature projection followed by the transformer stack.

    Args:
        feature_projection (Module): Maps extracted features to the transformer dimension.
        transformer (Module): Transformer called with ``attention_mask=`` and exposing
            ``get_intermediate_outputs``.
    """

    def __init__(
        self,
        feature_projection: Module,
        transformer: Module,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.transformer = transformer

    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Project ``features``; when ``lengths`` is given, zero padded frames and build an additive mask.

        Returns:
            The projected tensor and either ``None`` or a ``(batch, 1, max_len, max_len)`` mask
            holding ``-10000.0`` on padded key positions and ``0`` elsewhere.
        """
        x = self.feature_projection(features)
        mask: Optional[Tensor] = None
        if lengths is not None:
            batch_size, max_len, _ = x.shape
            # True where the frame index is >= the sample's valid length (padding).
            is_padding = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
            x[is_padding] = 0.0
            # Broadcast over the query axis and turn the boolean mask into an additive bias.
            mask = -10000.0 * is_padding[:, None, None, :].to(dtype=features.dtype)
            mask = mask.expand(batch_size, 1, max_len, max_len)
        return x, mask

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        projected, mask = self._preprocess(features, lengths)
        return self.transformer(projected, attention_mask=mask)

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        projected, mask = self._preprocess(features, lengths)
        return self.transformer.get_intermediate_outputs(projected, attention_mask=mask, num_layers=num_layers)
################################################################################
def _get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
) -> FeatureExtractor:
    """Build the convolutional feature extractor.

    Args:
        norm_mode (str):
            Either "group_norm" or "layer_norm".
            With "group_norm", only the first convolution block is normalized (GroupNorm with
            one group per channel). With "layer_norm", every convolution block gets a
            layer normalization over the channel axis.
            This option corresponds to "extractor_mode" from fairseq.
            Expected values are "group_norm" for Base arch, and
            "layer_norm" for Large arch.
        shapes (list of tuple of int):
            Configuration of convolution layers. List of convolution configuration,
            i.e. ``[(output_channel, kernel_size, stride), ...]``
            This option corresponds to "conv_feature_layers" from fairseq.
            Expected values are
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
            for all the architectures.
        bias (bool):
            Whether to include bias term to each convolution operation.
            This option corresponds to "conv_bias" from fairseq.
            Expected values are False for Base arch, and True for Large arch.

    Raises:
        ValueError: If ``norm_mode`` is not one of the two supported values.

    See Also:
        * Original implementation
          https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
        * "extractor_mode"
          - Def and base:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
          - Large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
        * "conv_feature_layers"
          - Def, base and large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
        * "conv_bias"
          - Def and base:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
          - Large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
    """
    if norm_mode not in ["group_norm", "layer_norm"]:
        raise ValueError("Invalid norm mode")

    def _make_norm(index: int, channels: int) -> Optional[Module]:
        # "layer_norm": every block is normalized.
        if norm_mode == "layer_norm":
            return LayerNorm(normalized_shape=channels, elementwise_affine=True)
        # "group_norm": only the first block is normalized.
        if index == 0:
            return nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)
        return None

    blocks = []
    prev_channels = 1  # raw waveform enters as a single channel
    for index, (out_channels, kernel_size, stride) in enumerate(shapes):
        blocks.append(
            ConvLayerBlock(
                in_channels=prev_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=bias,
                layer_norm=_make_norm(index, out_channels),
            )
        )
        prev_channels = out_channels
    return FeatureExtractor(nn.ModuleList(blocks))
def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    num_heads: int,
    attention_dropout: float,
    ff_interm_features: int,
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
) -> Encoder:
    """
    Args:
        in_features (int): The number of input features.
        embed_dim (int):
            The dimension of embedding.
            This option corresponds to "encoder_embed_dim" from fairseq.
            Expected values are 768 for Base arch, and 1024 for Large arch.
        dropout_input (float):
            The dropout probability applied after the input feature is projected
            to ``embed_dim``.
            This option corresponds to "dropout_input" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.
            This option corresponds to "conv_pos" from fairseq.
            Expected values are 128 for both Base and Large arch.
        pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.
            This option corresponds to "conv_pos_groups" from fairseq.
            Expected values are 16 for both Base and Large arch.
        num_layers (int):
            The number of self attention layers in transformer block.
            This option corresponds to "encoder_layers" from fairseq.
            Expected values are 12 for Base and 24 for Large arch.
        num_heads (int):
            The number of heads in self attention layers.
            This option corresponds to "encoder_attention_heads" from fairseq.
            Expected values are 12 for Base and 16 for Large arch.
        attention_dropout (float):
            The dropout probability applied after softmax in self-attention layer.
            This option corresponds to "attention_dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        ff_interm_features (int):
            The dimension of hidden features in feed forward layer.
            This option corresponds to "encoder_ffn_embed_dim" from fairseq.
            Expected values are 3072 for Base and 4096 for Large arch.
        ff_interm_dropout (float):
            The dropout probability applied in feedforward layer.
            This option corresponds to "activation_dropout" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        dropout (float):
            The dropout probability applied at the end of feed forward layer.
            This option corresponds to "dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        layer_norm_first (bool):
            Control the order of layer norm in transformer layer and each encoder layer.
            If True, in transformer layer, layer norm is applied after the features have
            passed through all the encoder layers (i.e. to the output of the last layer).
            In encoder layer, two layer norms are applied before self attention and
            before feed forward (pre-norm); the residual paths are not normalized.
            If False, in transformer layer, layer norm is applied before features are fed
            to encoder layers. In encoder layer, two layer norms are applied after self
            attention, before and after feed forward (post-norm).
            This option corresponds to "layer_norm_first" from fairseq.
            Expected values are False for Base and True for Large arch.
        layer_drop (float):
            Probability to drop each encoder layer during training.
            This option corresponds to "layerdrop" from fairseq.
            Expected values are 0.1 for both Base and Large arch.

    Returns:
        Encoder: ``FeatureProjection`` followed by a ``Transformer`` of ``num_layers``
        ``EncoderLayer`` s using ``SelfAttention``.

    See Also:
        * "encoder_embed_dim"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
        * "dropout_input"
          - Def, base and large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
        * "conv_pos"
          - Def, base and large
            NOTE: The description is wrong.
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
          - Usage
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
        * "conv_pos_groups"
          - Def, base and large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
        * "encoder_layers"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
        * "encoder_attention_heads"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
        * "attention_dropout"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
        * "encoder_ffn_embed_dim"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
        * "activation_dropout"
          - Def
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
          - Base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
        * "dropout"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
        * "layer_norm_first"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
        * "layerdrop"
          - Def
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
          - Base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
    """
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList()
    for _ in range(num_layers):
        attention = SelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
        )
        feed_forward = FeedForward(
            io_features=embed_dim,
            intermediate_features=ff_interm_features,
            intermediate_dropout=ff_interm_dropout,
            output_dropout=dropout,
        )
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
            )
        )
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        # Intentionally inverted: ``Transformer.layer_norm_first=True`` normalizes *before* the
        # encoder layers, which is what the post-norm (``layer_norm_first=False``) architecture needs;
        # pre-norm layers instead get the transformer-level norm after the last layer.
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _get_wavlm_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    num_heads: int,
    num_buckets: int,
    max_distance: int,
    attention_dropout: float,
    ff_interm_features: int,
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
) -> Encoder:
    """
    Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the arguments are
    the same as in :py:func:`_get_encoder`, so refer there for documentation. The only difference from the Wav2Vec2
    encoder is that it uses `WavLMSelfAttention` instead of `SelfAttention`, and it takes two additional parameters:
    `num_buckets` and `max_distance`.

    Args:
        in_features (int): See :py:func:`_get_encoder`.
        embed_dim (int): See :py:func:`_get_encoder`.
        dropout_input (float): See :py:func:`_get_encoder`.
        pos_conv_kernel (int): See :py:func:`_get_encoder`.
        pos_conv_groups (int): See :py:func:`_get_encoder`.
        num_layers (int): See :py:func:`_get_encoder`.
        num_heads (int): See :py:func:`_get_encoder`.
        num_buckets (int): Number of buckets for relative position embedding.
        max_distance (int): Maximum distance for relative position embedding.
        attention_dropout (float): See :py:func:`_get_encoder`.
        ff_interm_features (int): See :py:func:`_get_encoder`.
        ff_interm_dropout (float): See :py:func:`_get_encoder`.
        dropout (float): See :py:func:`_get_encoder`.
        layer_norm_first (bool): See :py:func:`_get_encoder`.
        layer_drop (float): See :py:func:`_get_encoder`.
    """
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    def _make_layer(index: int) -> EncoderLayer:
        # Only the first layer owns the relative position bias table.
        attention = WavLMSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_buckets=num_buckets,
            max_distance=max_distance,
            dropout=attention_dropout,
            has_relative_attention_bias=(index == 0),
        )
        feed_forward = FeedForward(
            io_features=embed_dim,
            intermediate_features=ff_interm_features,
            intermediate_dropout=ff_interm_dropout,
            output_dropout=dropout,
        )
        return EncoderLayer(
            attention=attention,
            dropout=dropout,
            layer_norm_first=layer_norm_first,
            feed_forward=feed_forward,
        )

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList([_make_layer(index) for index in range(num_layers)])

    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        # Inverted on purpose: Transformer's flag means "normalize before the layers".
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _sample_mask_span_lengths(mask_type: str, num_mask: int, mask_length: int, mask_other: float) -> List[int]:
    """Draw ``num_mask`` span lengths according to ``mask_type`` (see :py:func:`_compute_mask_indices`)."""
    if mask_type == "static":
        lengths = torch.full((num_mask,), mask_length)
    elif mask_type == "uniform":
        lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
    elif mask_type == "normal":
        lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
        lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
    elif mask_type == "poisson":
        # ``torch.poisson`` takes a tensor of rates; it has no ``size`` argument.
        rates = torch.full((num_mask,), float(mask_length))
        lengths = torch.round(torch.poisson(rates)).int()
    else:
        raise Exception(f"unknown mask selection: {mask_type}")
    return [int(v) for v in lengths.tolist()]


def _non_overlapping_mask_indices(sz: int, lengths: List[int], min_space: int) -> Tensor:
    """Place spans without overlap inside ``[0, sz)``, keeping ``min_space`` free frames between them.

    Longer spans are placed first. Each one goes into a free interval chosen with probability
    proportional to the interval size; spans that no longer fit anywhere are skipped.
    """
    mask_idc: List[int] = []
    keep_length = min(lengths)  # free intervals shorter than this are not worth keeping
    parts = [(0, sz)]
    for length in sorted(lengths, reverse=True):
        sizes = torch.tensor([e - s for s, e in parts], dtype=torch.int)
        sizes[sizes < length + min_space] = 0
        total = sizes.sum()
        if total == 0:
            break
        choice = int(torch.distributions.categorical.Categorical(sizes / total).sample())
        s, e = parts.pop(choice)
        # ``randint``'s upper bound is exclusive; ``+ 1`` allows a span that exactly fills
        # ``[s, e)`` (otherwise ``randint(s, s)`` raises).
        span_start = int(torch.randint(s, e - length + 1, size=(1,)))
        mask_idc.extend(range(span_start, span_start + length))
        if span_start - s - min_space >= keep_length:
            parts.append((s, span_start - min_space + 1))
        # Remaining space to the right of the span is ``e - (span_start + length + min_space)``.
        if e - span_start - length - min_space > keep_length:
            parts.append((span_start + length + min_space, e))
    return torch.tensor(mask_idc, dtype=torch.long)


def _overlapping_mask_indices(sz: int, num_mask: int, lengths: List[int]) -> Tensor:
    """Pick ``num_mask`` distinct span starts at random; spans may overlap and may run past ``sz``."""
    min_len = min(lengths)
    if sz - min_len <= num_mask:
        min_len = sz - num_mask - 1
    starts = torch.randperm(sz - min_len)[:num_mask].tolist()
    return torch.tensor(
        [start + offset for start, length in zip(starts, lengths) for offset in range(length)],
        dtype=torch.long,
    )


def _compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> Tensor:
    """Computes random mask spans for a given shape.

    Args:
        shape (int, int): The shape for which to compute masks.
            The first element is batch size and second is the number of frames.
        padding_mask (Tensor or None): The padding mask of the same dimension as shape,
            which will prevent masking padded elements. The number of ``True`` entries in row ``i``
            is treated as padding, and spans are only kept within the first ``frame - n_padded``
            frames (this assumes padding sits at the end of each row).
        mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
            This will be multiplied by number of timesteps divided by length of mask span to mask
            approximately this percentage of all elements. However due to overlaps, the actual number
            will be smaller (unless no_overlap is True).
        mask_length (int): Nominal span length; its exact role depends on ``mask_type``.
        mask_type (str): How to compute mask lengths. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
            ``static``: Fixed size ``mask_length``.
            ``uniform``: Sample from uniform distribution [mask_other, mask_length*2]
            ``normal``: Sample from normal distribution with mean ``mask_length`` and stdev ``mask_other``.
            ``poisson``: Sample from Poisson distribution with lambda = ``mask_length``.
        mask_other (float): Secondary argument of ``mask_type`` (lower bound for ``uniform``,
            standard deviation for ``normal``; unused otherwise).
        min_masks (int): Minimum number of masked spans.
        no_overlap (bool): If True, use an alternative algorithm that prevents spans
            from overlapping.
        min_space (int): How many frames to keep unmasked between spans (Only used if no_overlap is True).

    Returns:
        (Tensor): Boolean mask of dimension `[batch, frame]`; ``True`` marks a masked frame.
            Rows with more masked frames than the smallest row are randomly subsampled, so every
            row has the same number of ``True`` entries.

    Raises:
        Exception: If ``mask_type`` is unknown.
    """
    batch_size, frame = shape
    mask = torch.full((batch_size, frame), False)
    # add a random number for probabilistic rounding
    all_num_mask = int(mask_prob * frame / float(mask_length) + torch.rand(1))
    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs: List[Tensor] = []
    for i in range(batch_size):
        if padding_mask is not None:
            sz = frame - padding_mask[i].long().sum().item()
            # add a random number for probabilistic rounding
            num_mask = int(mask_prob * sz / float(mask_length) + torch.rand(1))
            num_mask = max(min_masks, num_mask)
        else:
            sz = frame
            num_mask = all_num_mask

        lengths = _sample_mask_span_lengths(mask_type, num_mask, mask_length, mask_other)
        if not lengths:
            # No span for this sample (e.g. small ``mask_prob`` with ``min_masks=0``).
            mask_idcs.append(torch.empty(0, dtype=torch.long))
            continue
        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = _non_overlapping_mask_indices(sz, lengths, min_space)
        else:
            mask_idc = _overlapping_mask_indices(sz, num_mask, lengths)
        mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))

    # Equalize the number of masked frames across the batch by subsampling larger rows.
    min_len = min((len(m) for m in mask_idcs), default=0)
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = mask_idc[torch.randperm(len(mask_idc))[:min_len].long()]
        mask[i, mask_idc] = True
    return mask
def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor:
    """Generate the padding mask given the padded input and the lengths Tensors.

    Args:
        input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`.
        lengths (Tensor): The lengths Tensor of dimension `[batch,]`.

    Returns:
        (Tensor): Boolean mask of dimension `[batch, max_len]`, ``True`` where the frame index
            is at or beyond the sample's length (i.e. padding).
    """
    batch_size, max_len, _ = input.shape
    frame_index = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len)
    return frame_index >= lengths.unsqueeze(1)
class MaskGenerator(Module):
    """Generate the masks for masked prediction.

    Args:
        encoder_embed_dim (int): The dimension of the transformer embedding output.
        mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
            This will be multiplied by number of timesteps divided by length of mask span to mask
            approximately this percentage of all elements. However due to overlaps, the actual number
            will be smaller (unless no_overlap is True).
        mask_selection (str): How to choose the mask length.
            Options: [``static``, ``uniform``, ``normal``, ``poisson``].
        mask_other (float): Secondary mask argument (used for more complex distributions).
        mask_length (int): The lengths of the mask.
        no_mask_overlap (bool): If ``True``, masked spans are not allowed to overlap.
        mask_min_space (int): Minimum space between spans (if no overlap is enabled).
        mask_channel_prob (float): The probability of replacing a feature with 0.
        mask_channel_selection (str): How to choose the mask length for channel masking.
            Options: [``static``, ``uniform``, ``normal``, ``poisson``].
        mask_channel_other (float): Secondary mask argument for channel masking
            (used for more complex distributions).
        mask_channel_length (int): The lengths of the mask for channel masking.
        no_mask_channel_overlap (bool): If ``True``, channel-masked spans are not allowed to overlap.
        mask_channel_min_space (int): Minimum space between spans for channel masking
            (if no overlap is enabled).
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        mask_prob: float,
        mask_selection: str,
        mask_other: float,
        mask_length: int,
        no_mask_overlap: bool,
        mask_min_space: int,
        mask_channel_prob: float,
        mask_channel_selection: str,
        mask_channel_other: float,
        mask_channel_length: int,
        no_mask_channel_overlap: bool,
        mask_channel_min_space: int,
    ):
        super().__init__()
        # Time-step (frame) masking configuration.
        self.mask_prob = mask_prob
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.mask_length = mask_length
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        # Channel (feature) masking configuration.
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.mask_channel_length = mask_channel_length
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        # Learned vector written into every masked frame; uniformly initialized.
        self.mask_embedding = Parameter(torch.FloatTensor(encoder_embed_dim))
        torch.nn.init.uniform_(self.mask_embedding)

    def forward(self, x: Tensor, padding_mask: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
        """Apply time-step masking and channel masking to ``x``.

        Note:
            ``x`` is modified in place. Masked time steps are overwritten with
            ``mask_embedding``, and masked channels are set to 0. The returned feature
            Tensor is the same object as ``x``.

        Args:
            x (Tensor): The encoded representations after feature extraction module,
                of dimension `[batch, frame, feature]`.
            padding_mask (Tensor or None): The padding mask of the same dimension as shape,
                which will prevent masking padded elements. It is only used for time-step
                masking, not for channel masking.

        Returns:
            (Tensor, Tensor or None):
            Tensor
                The feature representations after masking.
            Tensor or None
                The generated time-step mask indices (used as a boolean index into ``x``),
                or ``None`` when ``mask_prob`` is not positive.
        """
        B, T, C = x.shape
        # Time-step masking: replace the selected frames with the learned mask embedding.
        if self.mask_prob > 0:
            mask_indices = _compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = mask_indices.to(x.device)
            # change dtype of mask_embedding to x for mixed-precision training.
            # see https://github.com/pytorch/audio/issues/2847 for details.
            x[mask_indices] = self.mask_embedding.to(x.dtype)
        else:
            mask_indices = None
        # Channel masking: zero the selected feature channels in every frame.
        # No padding mask is passed here, so padding is not taken into account.
        if self.mask_channel_prob > 0:
            mask_channel_indices = _compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-(batch, channel) mask across all T frames.
            mask_channel_indices = mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
            x[mask_channel_indices] = 0
        return x, mask_indices
def _compute_logits(
proj_x: Tensor,
target: Tensor,
label_embeddings: Parameter,
) -> Tensor:
"""Compute the logits of the embeddings.
Args:
proj_x (Tensor): The projected masked representations of dimension `[batch, frame, final_dim]`.
target (Tensor): The target Tensor of dimension `[batch, frame, final_dim]`.
label_embeddings (Parameter): The trainable embeddings of target of dimension `[num_class, final_dim]`.
Returns:
(Tensor): The logits of the inputs.
"""
logit_temp = 0.1
pos = torch.index_select(label_embeddings, 0, target.long())
negs = label_embeddings.unsqueeze(1).expand(-1, proj_x.size(0), -1)
neg_is_pos = (pos == negs).all(-1)
pos = pos.unsqueeze(0)
targets = torch.cat([pos, negs], dim=0)
logits = torch.cosine_similarity(proj_x.float(), targets.float(), dim=-1).type_as(proj_x)
logits /= logit_temp
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
return logits
class LogitGenerator(Module):
    """Generate the logits of masked and unmasked inputs.

    Args:
        encoder_embed_dim (int): The dimension of the transformer embedding output.
        num_classes (int): The number of classes in the labels.
        final_dim (int): Project final representations and targets to `final_dim`.
        skip_masked (bool): If True, skip computing losses over masked frames.
        skip_nomask (bool): If True, skip computing losses over unmasked frames.
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        num_classes: int,
        final_dim: int,
        skip_masked: bool,
        skip_nomask: bool,
    ):
        super().__init__()
        # One trainable embedding per label class, uniformly initialized.
        self.label_embeddings = Parameter(torch.FloatTensor(num_classes, final_dim))
        torch.nn.init.uniform_(self.label_embeddings)
        self.final_proj = torch.nn.Linear(encoder_embed_dim, final_dim)
        self.skip_masked = skip_masked
        self.skip_nomask = skip_nomask

    def forward(
        self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Project ``x`` and compute logits separately for masked and unmasked frames.

        Args:
            x (Tensor): The feature representation of the last transformer layer
                (presumably of dimension `[batch, frame, encoder_embed_dim]` -- confirm with caller).
            label (Tensor): The label Tensor of dimension `[batch, frame]`.
            mask_m (Tensor): The masked indices of dimension `[batch, frame]`.
            mask_u (Tensor): The unmasked indices of dimension `[batch, frame]`.

        Returns:
            (Tensor or None, Tensor or None):
            Tensor or None
                The logits of masked frames, of dimension `[masked_frame, num_classes + 1]`,
                or ``None`` if ``skip_masked`` is ``True``.
            Tensor or None
                The logits of unmasked frames, of dimension `[unmasked_frame, num_classes + 1]`,
                or ``None`` if ``skip_nomask`` is ``True``.
        """
        proj_x = self.final_proj(x)
        if self.skip_masked:
            logit_m = None
        else:
            # Boolean indexing flattens the selected frames to [masked_frame, final_dim].
            proj_x_m = proj_x[mask_m]
            label_m = label[mask_m]
            logit_m = _compute_logits(proj_x_m, label_m, self.label_embeddings)
        if self.skip_nomask:
            logit_u = None
        else:
            proj_x_u = proj_x[mask_u]
            label_u = label[mask_u]
            logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings)
        return logit_m, logit_u
class GradMultiply(torch.autograd.Function):
    """Pass values through unchanged while scaling the gradient in the backward pass.

    Used as ``GradMultiply.apply(x, scale)``. The forward result holds the same values as
    ``x``. The backward pass multiplies the incoming gradient by ``scale`` and returns no
    gradient for ``scale``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        """Remember ``scale`` for backward and return ``x.new(x)`` (same values as ``x``)."""
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        """Scale the upstream gradient by the stored factor; ``scale`` itself gets ``None``."""
        scaled_grad = grad * ctx.scale
        return scaled_grad, None