331 lines
12 KiB
Python
331 lines
12 KiB
Python
"""Implements Conv-TasNet with building blocks of it.
|
||
|
||
Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c
|
||
"""
|
||
|
||
from typing import Optional, Tuple
|
||
|
||
import torch
|
||
|
||
|
||
class ConvBlock(torch.nn.Module):
    """Single 1D convolutional block used by the mask generator.

    The input is expanded to ``hidden_channels`` by a pointwise convolution,
    filtered by a depthwise (per-channel) convolution, and projected back to
    ``io_channels`` by up to two pointwise heads: one for the residual path
    and one for the skip path. Each of the first two convolutions is followed
    by a PReLU and a single-group ``GroupNorm``.

    Args:
        io_channels (int): The number of input/output channels, <B, Sc>.
        hidden_channels (int): The number of channels in the internal layers, <H>.
        kernel_size (int): The convolution kernel size of the middle layer, <P>.
        padding (int): Padding value of the convolution in the middle layer.
        dilation (int, optional): Dilation value of the convolution in the middle layer.
        no_residual (bool, optional): Disable residual block/output. When ``True``
            the residual head is not created and ``forward`` returns ``None`` for it.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.
    """

    def __init__(
        self,
        io_channels: int,
        hidden_channels: int,
        kernel_size: int,
        padding: int,
        dilation: int = 1,
        no_residual: bool = False,
    ):
        super().__init__()

        # The list order is both the execution order and the index used in
        # the state-dict keys (``conv_layers.0`` ... ``conv_layers.5``).
        stages = [
            # Pointwise expansion: io_channels -> hidden_channels.
            torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
            # Depthwise convolution: one filter per channel (groups == channels).
            torch.nn.Conv1d(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                padding=padding,
                dilation=dilation,
                groups=hidden_channels,
            ),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
        ]
        self.conv_layers = torch.nn.Sequential(*stages)

        # The residual head is optional; the skip head always exists.
        if no_residual:
            self.res_out = None
        else:
            self.res_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
        self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)

    def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Run the block.

        Args:
            input (torch.Tensor): Tensor whose channel dimension is ``io_channels``.

        Returns:
            (Tensor or None, Tensor): The residual output (``None`` when the
            block was built with ``no_residual=True``) and the skip output.
        """
        hidden = self.conv_layers(input)
        residual: Optional[torch.Tensor] = None
        if self.res_out is not None:
            residual = self.res_out(hidden)
        return residual, self.skip_out(hidden)
|
||
|
||
|
||
class MaskGenerator(torch.nn.Module):
    """TCN (Temporal Convolution Network) Separation Module

    Generates masks for separation.

    Args:
        input_dim (int): Input feature dimension, <N>.
        num_sources (int): The number of sources to separate.
        kernel_size (int): The convolution kernel size of conv blocks, <P>.
            Use an odd value: the blocks are padded symmetrically so that the
            number of frames is preserved, which the residual connections in
            ``forward`` require.
        num_feats (int): Input/output feature dimension of conv blocks, <B, Sc>.
        num_hidden (int): Intermediate feature dimension of conv blocks, <H>.
        num_layers (int): The number of conv blocks in one stack, <X>.
        num_stacks (int): The number of conv block stacks, <R>.
        msk_activate (str): The activation function of the mask output,
            either ``"sigmoid"`` or ``"relu"``.

    Raises:
        ValueError: If ``msk_activate`` is neither ``"sigmoid"`` nor ``"relu"``.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.
    """

    def __init__(
        self,
        input_dim: int,
        num_sources: int,
        kernel_size: int,
        num_feats: int,
        num_hidden: int,
        num_layers: int,
        num_stacks: int,
        msk_activate: str,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_sources = num_sources

        self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
        self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)

        self.receptive_field = 0
        self.conv_layers = torch.nn.ModuleList([])
        for stack_idx in range(num_stacks):
            for layer_idx in range(num_layers):
                # Dilation doubles with every layer and restarts in each stack.
                multi = 2**layer_idx
                is_last_block = layer_idx == (num_layers - 1) and stack_idx == (num_stacks - 1)
                self.conv_layers.append(
                    ConvBlock(
                        io_channels=num_feats,
                        hidden_channels=num_hidden,
                        kernel_size=kernel_size,
                        dilation=multi,
                        # "Same" padding for a stride-1 dilated convolution.
                        # Equals ``multi`` for kernel_size == 3 and keeps the
                        # frame count unchanged for every odd kernel size.
                        padding=multi * (kernel_size - 1) // 2,
                        # The last ConvBlock does not need residual
                        no_residual=is_last_block,
                    )
                )
                # The first block contributes its full kernel; every later
                # block widens the field by (kernel_size - 1) * dilation.
                if stack_idx == 0 and layer_idx == 0:
                    self.receptive_field += kernel_size
                else:
                    self.receptive_field += (kernel_size - 1) * multi
        self.output_prelu = torch.nn.PReLU()
        self.output_conv = torch.nn.Conv1d(
            in_channels=num_feats,
            out_channels=input_dim * num_sources,
            kernel_size=1,
        )
        if msk_activate == "sigmoid":
            self.mask_activate = torch.nn.Sigmoid()
        elif msk_activate == "relu":
            self.mask_activate = torch.nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation {msk_activate}")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Generate separation mask.

        Args:
            input (torch.Tensor): 3D Tensor with shape [batch, features, frames]

        Returns:
            Tensor: shape [batch, num_sources, features, frames]
        """
        batch_size = input.shape[0]
        feats = self.input_norm(input)
        feats = self.input_conv(feats)
        # Skip outputs of all blocks are summed; the scalar start value is
        # replaced by a tensor on the first addition.
        output = 0.0
        for layer in self.conv_layers:
            residual, skip = layer(feats)
            if residual is not None:  # the last conv layer does not produce residual
                feats = feats + residual
            output = output + skip
        output = self.output_prelu(output)
        output = self.output_conv(output)
        output = self.mask_activate(output)
        # Split the stacked channel axis into (num_sources, input_dim).
        return output.view(batch_size, self.num_sources, self.input_dim, -1)
|
||
|
||
|
||
class ConvTasNet(torch.nn.Module):
    """Conv-TasNet architecture introduced in
    *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
    :cite:`Luo_2019`.

    The model is a learned 1D convolutional encoder, a mask generator that
    produces one mask per source, and a transposed-convolution decoder that
    turns each masked representation back into a waveform.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        num_sources (int, optional): The number of sources to split.
        enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
        enc_num_feats (int, optional): The feature dimensions passed to mask generator, <N>.
        msk_kernel_size (int, optional): The convolution kernel size of the mask generator, <P>.
        msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
        msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, <H>.
        msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
        msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
        msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
    """

    def __init__(
        self,
        num_sources: int = 2,
        # encoder/decoder parameters
        enc_kernel_size: int = 16,
        enc_num_feats: int = 512,
        # mask generator parameters
        msk_kernel_size: int = 3,
        msk_num_feats: int = 128,
        msk_num_hidden_feats: int = 512,
        msk_num_layers: int = 8,
        msk_num_stacks: int = 3,
        msk_activate: str = "sigmoid",
    ):
        super().__init__()

        self.num_sources = num_sources
        self.enc_num_feats = enc_num_feats
        self.enc_kernel_size = enc_kernel_size
        # Windows overlap by half a kernel.
        self.enc_stride = enc_kernel_size // 2

        # Encoder and decoder mirror each other, so they share these settings.
        codec_settings = {
            "kernel_size": enc_kernel_size,
            "stride": self.enc_stride,
            "padding": self.enc_stride,
            "bias": False,
        }

        self.encoder = torch.nn.Conv1d(in_channels=1, out_channels=enc_num_feats, **codec_settings)
        self.mask_generator = MaskGenerator(
            input_dim=enc_num_feats,
            num_sources=num_sources,
            kernel_size=msk_kernel_size,
            num_feats=msk_num_feats,
            num_hidden=msk_num_hidden_feats,
            num_layers=msk_num_layers,
            num_stacks=msk_num_stacks,
            msk_activate=msk_activate,
        )
        self.decoder = torch.nn.ConvTranspose1d(in_channels=enc_num_feats, out_channels=1, **codec_settings)

    def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Zero-pad the end of the input so that its last frame lines up with

        1. (if kernel size is odd) the center of the last convolution kernel
        or 2. (if kernel size is even) the end of the first half of the last convolution kernel

        Assumption:
            The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
            on the both ends in Conv1D

        |<--- k_1 --->|
        |      |            |<-- k_n-1 -->|
        |      |                  |  |<--- k_n --->|
        |      |                  |         |      |
        |      |                  |         |      |
        |      v                  v         v      |
        |<---->|<--- input signal --->|<--->|<---->|
         stride                         PAD  stride

        Args:
            input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)

        Returns:
            Tensor: Padded Tensor
            int: Number of paddings performed
        """
        stride = self.enc_stride
        # An odd kernel has a center sample that is not part of any stride.
        center = self.enc_kernel_size % 2
        total_frames = input.size(2)

        whole_strides = (total_frames - center) // stride
        leftover = total_frames - (center + whole_strides * stride)
        if leftover == 0:
            # Already aligned: hand the input back unchanged.
            return input, 0

        # Fill the partially covered stride up with zeros.
        num_paddings = stride - leftover
        tail = torch.zeros(
            input.size(0),
            input.size(1),
            num_paddings,
            dtype=input.dtype,
            device=input.device,
        )
        return torch.cat([input, tail], 2), num_paddings

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Perform source separation. Generate audio source waveforms.

        Args:
            input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]

        Returns:
            Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]

        Raises:
            ValueError: If ``input`` is not 3D or its channel dimension is not 1.
        """
        if not (input.ndim == 3 and input.shape[1] == 1):
            raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")

        # Shape legend for the comments below:
        #   B: batch size
        #   L: input frame length
        #   L': padded input frame length
        #   F: feature dimension
        #   M: feature frame length
        #   S: number of sources

        padded, num_pads = self._align_num_frames_with_strides(input)  # B, 1, L'
        batch_size = padded.size(0)
        num_padded_frames = padded.size(2)

        feats = self.encoder(padded)  # B, F, M
        masks = self.mask_generator(feats)  # B, S, F, M
        masked = masks * feats.unsqueeze(1)  # B, S, F, M
        # Fold sources into the batch axis so the decoder handles each one separately.
        masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1)  # B*S, F, M
        decoded = self.decoder(masked)  # B*S, 1, L'
        output = decoded.view(batch_size, self.num_sources, num_padded_frames)  # B, S, L'

        if num_pads > 0:
            # Drop the frames that only exist because of the alignment padding.
            output = output[..., :-num_pads]  # B, S, L
        return output
|
||
|
||
|
||
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
    r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`.

    The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
    except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.

    Args:
        num_sources (int, optional): Number of sources in the output.
            (Default: 2)

    Returns:
        ConvTasNet:
            ConvTasNet model.
    """
    # Fixed hyperparameters of the base configuration; only the number of
    # sources is left to the caller.
    base_config = {
        "enc_kernel_size": 16,
        "enc_num_feats": 512,
        "msk_kernel_size": 3,
        "msk_num_feats": 128,
        "msk_num_hidden_feats": 512,
        "msk_num_layers": 8,
        "msk_num_stacks": 3,
        "msk_activate": "relu",
    }
    return ConvTasNet(num_sources=num_sources, **base_config)
|