from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torchaudio.models import Emformer


__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]


class _TimeReduction(torch.nn.Module):
    r"""Coalesces frames along the time dimension into fewer frames
    with higher feature dimensionality.

    Args:
        stride (int): number of frames to merge for each output frame.
    """

    def __init__(self, stride: int) -> None:
        super().__init__()
        self.stride = stride

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame.

        Args:
            input (torch.Tensor): input sequences, with shape `(B, T, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output sequences, with shape
                    `(B, T // stride, D * stride)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output sequences.
        """
        B, T, D = input.shape
        num_frames = T - (T % self.stride)
        input = input[:, :num_frames, :]
        lengths = lengths.div(self.stride, rounding_mode="trunc")
        T_max = num_frames // self.stride

        output = input.reshape(B, T_max, D * self.stride)
        output = output.contiguous()
        return output, lengths
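

# A minimal usage sketch of ``_TimeReduction`` (illustrative only; the tensor
# shapes and values below are assumptions chosen for this example, not part of
# the original module):
#
#   >>> time_reduction = _TimeReduction(stride=4)
#   >>> frames = torch.rand(2, 10, 80)        # (B=2, T=10, D=80)
#   >>> lengths = torch.tensor([10, 7])
#   >>> output, out_lengths = time_reduction(frames, lengths)
#   >>> output.shape                          # trailing T % stride frames are dropped
#   torch.Size([2, 2, 320])
#   >>> out_lengths                           # lengths divided by stride (truncated)
#   tensor([2, 1])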


class _CustomLSTM(torch.nn.Module):
    r"""Custom long-short-term memory (LSTM) block that applies layer normalization
    to internal nodes.

    Args:
        input_dim (int): input dimension.
        hidden_dim (int): hidden dimension.
        layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
        layer_norm_epsilon (float, optional): value of epsilon to use in
            layer normalization layers. (Default: 1e-5)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        layer_norm: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
        self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
        if layer_norm:
            self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
            self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
        else:
            self.c_norm = torch.nn.Identity()
            self.g_norm = torch.nn.Identity()

        self.hidden_dim = hidden_dim

    def forward(
        self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""Forward pass.

        B: batch size;
        T: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): with shape `(T, B, D)`.
            state (List[torch.Tensor] or None): list of tensors
                representing internal state generated in preceding invocation
                of ``forward``.

        Returns:
            (torch.Tensor, List[torch.Tensor]):
                torch.Tensor
                    output, with shape `(T, B, hidden_dim)`.
                List[torch.Tensor]
                    list of tensors representing internal state generated
                    in current invocation of ``forward``.
        """
        if state is None:
            B = input.size(1)
            h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
            c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
        else:
            h, c = state

        gated_input = self.x2g(input)
        outputs = []
        for gates in gated_input.unbind(0):
            gates = gates + self.p2g(h)
            gates = self.g_norm(gates)
            input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
            input_gate = input_gate.sigmoid()
            forget_gate = forget_gate.sigmoid()
            cell_gate = cell_gate.tanh()
            output_gate = output_gate.sigmoid()
            c = forget_gate * c + input_gate * cell_gate
            c = self.c_norm(c)
            h = output_gate * c.tanh()
            outputs.append(h)

        output = torch.stack(outputs, dim=0)
        state = [h, c]

        return output, state
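

# A minimal usage sketch of ``_CustomLSTM`` (illustrative only; the dimensions
# below are assumptions). Note that the input is time-major, i.e. `(T, B, D)`,
# and that the returned ``[h, c]`` state can be fed back into the next call:
#
#   >>> lstm = _CustomLSTM(input_dim=8, hidden_dim=16, layer_norm=True)
#   >>> x = torch.rand(5, 3, 8)               # (T=5, B=3, D=8)
#   >>> out, state = lstm(x, None)            # no previous state
#   >>> out.shape
#   torch.Size([5, 3, 16])
#   >>> [s.shape for s in state]              # [h, c], each (B, hidden_dim)
#   [torch.Size([3, 16]), torch.Size([3, 16])]
#   >>> out2, state = lstm(x, state)          # carry state across calls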


class _Transcriber(ABC):
    @abstractmethod
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pass

    @abstractmethod
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        pass


class _EmformerEncoder(torch.nn.Module, _Transcriber):
    r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).

    Args:
        input_dim (int): feature dimension of each input sequence element.
        output_dim (int): feature dimension of each output sequence element.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context.
        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
        transformer_activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: str = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()
        self.input_linear = torch.nn.Linear(
            input_dim,
            time_reduction_input_dim,
            bias=False,
        )
        self.time_reduction = _TimeReduction(time_reduction_stride)
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            segment_length // time_reduction_stride,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=right_context_length // time_reduction_stride,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for inference.

        B: batch size;
        T: maximum input sequence segment length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            states (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``infer``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation
                    of ``infer``.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        (
            transformer_out,
            transformer_lengths,
            transformer_states,
        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths, transformer_states


class _Predictor(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) prediction network.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
        self.lstm_layers = torch.nn.ModuleList(
            [
                _CustomLSTM(
                    symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
                for idx in range(num_lstm_layers)
            ]
        )
        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)

        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid symbols for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output encoding sequences.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``forward``.
        """
        input_tb = input.permute(1, 0)
        embedding_out = self.embedding(input_tb)
        input_layer_norm_out = self.input_layer_norm(embedding_out)

        lstm_out = input_layer_norm_out
        state_out: List[List[torch.Tensor]] = []
        for layer_idx, lstm in enumerate(self.lstm_layers):
            lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
            lstm_out = self.dropout(lstm_out)
            state_out.append(lstm_state_out)

        linear_out = self.linear(lstm_out)
        output_layer_norm_out = self.output_layer_norm(linear_out)
        return output_layer_norm_out.permute(1, 0, 2), lengths, state_out


class _Joiner(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) joint network.

    Args:
        input_dim (int): source and target input dimension.
        output_dim (int): output dimension.
        activation (str, optional): activation function to use in the joiner.
            Must be one of ("relu", "tanh"). (Default: "relu")
    """

    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        elif activation == "tanh":
            self.activation = torch.nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation {activation}")

    def forward(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
        activation_out = self.activation(joint_encodings)
        output = self.linear(activation_out)
        return output, source_lengths, target_lengths
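

# A minimal usage sketch of ``_Joiner`` (illustrative only; the dimensions are
# assumptions). The joiner broadcasts a `(B, T, D)` source encoding against a
# `(B, U, D)` target encoding to produce `(B, T, U, output_dim)` logits:
#
#   >>> joiner = _Joiner(input_dim=1024, output_dim=4097)
#   >>> source = torch.rand(2, 30, 1024)      # transcription network output
#   >>> target = torch.rand(2, 12, 1024)      # prediction network output
#   >>> out, src_len, tgt_len = joiner(
#   ...     source, torch.tensor([30, 25]), target, torch.tensor([12, 9])
#   ... )
#   >>> out.shape
#   torch.Size([2, 30, 12, 4097])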


class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    """

    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
        super().__init__()
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid symbols for i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing prediction network internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing prediction network internal state generated in current invocation
                    of ``forward``.
        """
        source_encodings, source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        target_encodings, target_lengths, predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )

        return (
            output,
            source_lengths,
            target_lengths,
            predictor_state,
        )

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing transcription network internal state generated in preceding invocation
                of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing transcription network internal state generated in current invocation
                    of ``transcribe_streaming``.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid symbols for i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return output, source_lengths, target_lengths
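

# A rough streaming sketch for ``RNNT.transcribe_streaming`` (illustrative only;
# the chunking arithmetic assumes the ``emformer_rnnt_base`` configuration
# defined below, i.e. segment_length=16 and right_context_length=4, so each
# call consumes a 20-frame chunk of 80-dimensional features; the vocabulary
# size of 4097 is likewise an assumption):
#
#   >>> model = emformer_rnnt_base(num_symbols=4097).eval()
#   >>> state: Optional[List[List[torch.Tensor]]] = None
#   >>> for _ in range(3):                             # three consecutive segments
#   ...     chunk = torch.rand(1, 16 + 4, 80)          # segment + right context
#   ...     lengths = torch.tensor([20])
#   ...     enc, enc_lengths, state = model.transcribe_streaming(chunk, lengths, state)
#
# Each call returns encodings for the current segment plus the transcription
# network state to pass to the next call; a decoder (e.g. a beam search such as
# ``torchaudio.models.RNNTBeamSearch``) would consume ``enc``.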


def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: str,
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
) -> RNNT:
    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
        frames.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    encoder = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    predictor = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols)
    return RNNT(encoder, predictor, joiner)


def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    return emformer_rnnt_model(
        input_dim=80,
        encoding_dim=1024,
        num_symbols=num_symbols,
        segment_length=16,
        right_context_length=4,
        time_reduction_input_dim=128,
        time_reduction_stride=4,
        transformer_num_heads=8,
        transformer_ffn_dim=2048,
        transformer_num_layers=20,
        transformer_dropout=0.1,
        transformer_activation="gelu",
        transformer_left_context_length=30,
        transformer_max_memory_size=0,
        transformer_weight_init_scale_strategy="depthwise",
        transformer_tanh_on_mem=True,
        symbol_embedding_dim=512,
        num_lstm_layers=3,
        lstm_layer_norm=True,
        lstm_layer_norm_epsilon=1e-3,
        lstm_dropout=0.3,
    )
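

# A rough end-to-end sketch of a training-style forward pass (illustrative
# only; the batch size, sequence lengths, and the vocabulary size of 4097 are
# assumptions, not values prescribed by this module):
#
#   >>> model = emformer_rnnt_base(num_symbols=4097)
#   >>> sources = torch.rand(2, 132, 80)            # features right-padded with 4 right-context frames
#   >>> source_lengths = torch.tensor([132, 100])
#   >>> targets = torch.randint(0, 4097, (2, 12))   # token ids
#   >>> target_lengths = torch.tensor([12, 9])
#   >>> logits, src_len, tgt_len, pred_state = model(
#   ...     sources, source_lengths, targets, target_lengths
#   ... )
#   >>> logits.shape                                # (B, T', U, num_symbols), with T' roughly T // 4
#
# The `(B, T', U, num_symbols)` logits are the usual input to an RNN-T loss
# (for example ``torchaudio.functional.rnnt_loss``), with targets conventionally
# prepended with a blank symbol before being fed to the prediction network.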