import math from typing import List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn, Tensor __all__ = [ "ResBlock", "MelResNet", "Stretch2d", "UpsampleNetwork", "WaveRNN", ] class ResBlock(nn.Module): r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`. Args: n_freq: the number of bins in a spectrogram. (Default: ``128``) Examples >>> resblock = ResBlock() >>> input = torch.rand(10, 128, 512) # a random spectrogram >>> output = resblock(input) # shape: (10, 128, 512) """ def __init__(self, n_freq: int = 128) -> None: super().__init__() self.resblock_model = nn.Sequential( nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), nn.BatchNorm1d(n_freq), nn.ReLU(inplace=True), nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), nn.BatchNorm1d(n_freq), ) def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the ResBlock layer. Args: specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time). Return: Tensor shape: (n_batch, n_freq, n_time) """ return self.resblock_model(specgram) + specgram class MelResNet(nn.Module): r"""MelResNet layer uses a stack of ResBlocks on spectrogram. Args: n_res_block: the number of ResBlock in stack. (Default: ``10``) n_freq: the number of bins in a spectrogram. (Default: ``128``) n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) n_output: the number of output dimensions of melresnet. (Default: ``128``) kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) Examples >>> melresnet = MelResNet() >>> input = torch.rand(10, 128, 512) # a random spectrogram >>> output = melresnet(input) # shape: (10, 128, 508) """ def __init__( self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5 ) -> None: super().__init__() ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)] self.melresnet_model = nn.Sequential( nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True), *ResBlocks, nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1), ) def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the MelResNet layer. Args: specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time). Return: Tensor shape: (n_batch, n_output, n_time - kernel_size + 1) """ return self.melresnet_model(specgram) class Stretch2d(nn.Module): r"""Upscale the frequency and time dimensions of a spectrogram. Args: time_scale: the scale factor in time dimension freq_scale: the scale factor in frequency dimension Examples >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5) >>> input = torch.rand(10, 100, 512) # a random spectrogram >>> output = stretch2d(input) # shape: (10, 500, 5120) """ def __init__(self, time_scale: int, freq_scale: int) -> None: super().__init__() self.freq_scale = freq_scale self.time_scale = time_scale def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the Stretch2d layer. Args: specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time). Return: Tensor shape: (..., n_freq * freq_scale, n_time * time_scale) """ return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1) class UpsampleNetwork(nn.Module): r"""Upscale the dimensions of a spectrogram. Args: upsample_scales: the list of upsample scales. n_res_block: the number of ResBlock in stack. (Default: ``10``) n_freq: the number of bins in a spectrogram. (Default: ``128``) n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) n_output: the number of output dimensions of melresnet. (Default: ``128``) kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) Examples >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16]) >>> input = torch.rand(10, 128, 10) # a random spectrogram >>> output = upsamplenetwork(input) # shape: (10, 128, 1536), (10, 128, 1536) """ def __init__( self, upsample_scales: List[int], n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5, ) -> None: super().__init__() total_scale = 1 for upsample_scale in upsample_scales: total_scale *= upsample_scale self.total_scale: int = total_scale self.indent = (kernel_size - 1) // 2 * total_scale self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) self.resnet_stretch = Stretch2d(total_scale, 1) up_layers = [] for scale in upsample_scales: stretch = Stretch2d(scale, 1) conv = nn.Conv2d( in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False ) torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1)) up_layers.append(stretch) up_layers.append(conv) self.upsample_layers = nn.Sequential(*up_layers) def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]: r"""Pass the input through the UpsampleNetwork layer. Args: specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time) Return: Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale), (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) where total_scale is the product of all elements in upsample_scales. """ resnet_output = self.resnet(specgram).unsqueeze(1) resnet_output = self.resnet_stretch(resnet_output) resnet_output = resnet_output.squeeze(1) specgram = specgram.unsqueeze(1) upsampling_output = self.upsample_layers(specgram) upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent] return upsampling_output, resnet_output class WaveRNN(nn.Module): r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn` based on the implementation from `fatchord/WaveRNN `_. The original implementation was introduced in *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1. The product of `upsample_scales` must equal `hop_length`. See Also: * `Training example `__ * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model. Args: upsample_scales: the list of upsample scales. n_classes: the number of output classes. hop_length: the number of samples between the starts of consecutive frames. n_res_block: the number of ResBlock in stack. (Default: ``10``) n_rnn: the dimension of RNN layer. (Default: ``512``) n_fc: the dimension of fully connected layer. (Default: ``512``) kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) n_freq: the number of bins in a spectrogram. (Default: ``128``) n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) n_output: the number of output dimensions of melresnet. (Default: ``128``) Example >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200) >>> waveform, sample_rate = torchaudio.load(file) >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) >>> output = wavernn(waveform, specgram) >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes) """ def __init__( self, upsample_scales: List[int], n_classes: int, hop_length: int, n_res_block: int = 10, n_rnn: int = 512, n_fc: int = 512, kernel_size: int = 5, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, ) -> None: super().__init__() self.kernel_size = kernel_size self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2 self.n_rnn = n_rnn self.n_aux = n_output // 4 self.hop_length = hop_length self.n_classes = n_classes self.n_bits: int = int(math.log2(self.n_classes)) total_scale = 1 for upsample_scale in upsample_scales: total_scale *= upsample_scale if total_scale != self.hop_length: raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True) self.relu1 = nn.ReLU(inplace=True) self.relu2 = nn.ReLU(inplace=True) self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc) self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc) self.fc3 = nn.Linear(n_fc, self.n_classes) def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: r"""Pass the input through the WaveRNN model. Args: waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length) specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time) Return: Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes) """ if waveform.size(1) != 1: raise ValueError("Require the input channel of waveform is 1") if specgram.size(1) != 1: raise ValueError("Require the input channel of specgram is 1") # remove channel dimension until the end waveform, specgram = waveform.squeeze(1), specgram.squeeze(1) batch_size = waveform.size(0) h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) # output of upsample: # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) specgram, aux = self.upsample(specgram) specgram = specgram.transpose(1, 2) aux = aux.transpose(1, 2) aux_idx = [self.n_aux * i for i in range(5)] a1 = aux[:, :, aux_idx[0] : aux_idx[1]] a2 = aux[:, :, aux_idx[1] : aux_idx[2]] a3 = aux[:, :, aux_idx[2] : aux_idx[3]] a4 = aux[:, :, aux_idx[3] : aux_idx[4]] x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1) x = self.fc(x) res = x x, _ = self.rnn1(x, h1) x = x + res res = x x = torch.cat([x, a2], dim=-1) x, _ = self.rnn2(x, h2) x = x + res x = torch.cat([x, a3], dim=-1) x = self.fc1(x) x = self.relu1(x) x = torch.cat([x, a4], dim=-1) x = self.fc2(x) x = self.relu2(x) x = self.fc3(x) # bring back channel dimension return x.unsqueeze(1) @torch.jit.export def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r"""Inference method of WaveRNN. This function currently only supports multinomial sampling, which assumes the network is trained on cross entropy loss. Args: specgram (Tensor): Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`. lengths (Tensor or None, optional): Indicates the valid length of each audio in the batch. Shape: `(batch, )`. When the ``specgram`` contains spectrograms with different durations, by providing ``lengths`` argument, the model will compute the corresponding valid output lengths. If ``None``, it is assumed that all the audio in ``waveforms`` have valid length. Default: ``None``. Returns: (Tensor, Optional[Tensor]): Tensor The inferred waveform of size `(n_batch, 1, n_time)`. 1 stands for a single channel. Tensor or None If ``lengths`` argument was provided, a Tensor of shape `(batch, )` is returned. It indicates the valid length in time axis of the output Tensor. """ device = specgram.device dtype = specgram.dtype specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad)) specgram, aux = self.upsample(specgram) if lengths is not None: lengths = lengths * self.upsample.total_scale output: List[Tensor] = [] b_size, _, seq_len = specgram.size() h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) x = torch.zeros((b_size, 1), device=device, dtype=dtype) aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)] for i in range(seq_len): m_t = specgram[:, :, i] a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split] x = torch.cat([x, m_t, a1_t], dim=1) x = self.fc(x) _, h1 = self.rnn1(x.unsqueeze(1), h1) x = x + h1[0] inp = torch.cat([x, a2_t], dim=1) _, h2 = self.rnn2(inp.unsqueeze(1), h2) x = x + h2[0] x = torch.cat([x, a3_t], dim=1) x = F.relu(self.fc1(x)) x = torch.cat([x, a4_t], dim=1) x = F.relu(self.fc2(x)) logits = self.fc3(x) posterior = F.softmax(logits, dim=1) x = torch.multinomial(posterior, 1).float() # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1] x = 2 * x / (2**self.n_bits - 1.0) - 1.0 output.append(x) return torch.stack(output).permute(1, 2, 0), lengths