import json
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Tuple

import torch
import torchaudio
from torchaudio._internal import module_utils
from torchaudio.models import emformer_rnnt_base, RNNT, RNNTBeamSearch


__all__ = []

# Gain applied to the power mel spectrogram before log-compression. The factor of 2
# in the decibel value accounts for the squared magnitudes, so this is equivalent to
# scaling the raw waveform into the int16 range before featurization.
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)


def _piecewise_linear_log(x):
    # Log-compress in place: logarithmic above ``e``, linear (scaled by 1/e) at or below.
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x
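

# A minimal sketch of the compression above (illustrative values, not part of the
# pipeline): entries above ``e`` are log-compressed, the rest are scaled by 1/e.
#
#   >>> _piecewise_linear_log(torch.tensor([1.0, 100.0]))
#   tensor([0.3679, 4.6052])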


class _FunctionalModule(torch.nn.Module):
    # Wraps a plain callable as a ``torch.nn.Module`` so that arbitrary functions
    # can be composed inside ``torch.nn.Sequential`` pipelines.
    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        return self.functional(input)
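

# For instance (a sketch, not used directly by the pipeline), a tensor transpose
# can be packaged as a module this way:
#
#   >>> transpose = _FunctionalModule(lambda x: x.transpose(1, 0))
#   >>> transpose(torch.zeros(80, 100)).shape
#   torch.Size([100, 80])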


class _GlobalStatsNormalization(torch.nn.Module):
    # Normalizes features with dataset-wide mean and inverse standard deviation
    # statistics loaded from a JSON file.
    def __init__(self, global_stats_path):
        super().__init__()

        with open(global_stats_path) as f:
            blob = json.loads(f.read())

        self.register_buffer("mean", torch.tensor(blob["mean"]))
        self.register_buffer("invstddev", torch.tensor(blob["invstddev"]))

    def forward(self, input):
        return (input - self.mean) * self.invstddev
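

# The stats file consumed above is plain JSON with per-feature statistics; a sketch
# of its shape (values made up here; real files carry one entry per mel bin):
#
#   {"mean": [-1.72, -1.58, ...], "invstddev": [0.27, 0.26, ...]}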


class _FeatureExtractor(ABC):
    @abstractmethod
    def __call__(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generates features and length output from the given input tensor.

        Args:
            input (torch.Tensor): input tensor.

        Returns:
            (torch.Tensor, torch.Tensor):
            torch.Tensor:
                Features, with shape `(length, *)`.
            torch.Tensor:
                Length, with shape `(1,)`.
        """


class _TokenProcessor(ABC):
    @abstractmethod
    def __call__(self, tokens: List[int], **kwargs) -> str:
        """Decodes given list of tokens to text sequence.

        Args:
            tokens (List[int]): list of tokens to decode.

        Returns:
            str:
                Decoded text sequence.
        """


class _ModuleFeatureExtractor(torch.nn.Module, _FeatureExtractor):
    """``torch.nn.Module``-based feature extraction pipeline.

    Args:
        pipeline (torch.nn.Module): module that implements feature extraction logic.
    """

    def __init__(self, pipeline: torch.nn.Module) -> None:
        super().__init__()
        self.pipeline = pipeline

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generates features and length output from the given input tensor.

        Args:
            input (torch.Tensor): input tensor.

        Returns:
            (torch.Tensor, torch.Tensor):
            torch.Tensor:
                Features, with shape `(length, *)`.
            torch.Tensor:
                Length, with shape `(1,)`.
        """
        features = self.pipeline(input)
        length = torch.tensor([features.shape[0]])
        return features, length
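

# Illustrative composition (a sketch; ``RNNTBundle.get_feature_extractor`` below
# builds the real pipeline, with log-compression, normalization, and padding):
#
#   >>> extractor = _ModuleFeatureExtractor(
#   ...     torch.nn.Sequential(
#   ...         torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160),
#   ...         _FunctionalModule(lambda x: x.transpose(1, 0)),
#   ...     )
#   ... )
#   >>> features, length = extractor(torch.zeros(16000))  # (num_frames, 80), (1,)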


class _SentencePieceTokenProcessor(_TokenProcessor):
    """SentencePiece-model-based token processor.

    Args:
        sp_model_path (str): path to SentencePiece model.
    """

    def __init__(self, sp_model_path: str) -> None:
        if not module_utils.is_module_available("sentencepiece"):
            raise RuntimeError("SentencePiece is not available. Please install it.")

        import sentencepiece as spm

        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.post_process_remove_list = {
            self.sp_model.unk_id(),
            self.sp_model.eos_id(),
            self.sp_model.pad_id(),
        }

    def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
        """Decodes given list of tokens to text sequence.

        Args:
            tokens (List[int]): list of tokens to decode.
            lstrip (bool, optional): if ``True``, returns text sequence with leading whitespace
                removed. (Default: ``True``).

        Returns:
            str:
                Decoded text sequence.
        """
        # Skip the hypothesis's leading token and drop unk/eos/pad tokens before decoding.
        filtered_hypo_tokens = [
            token_index for token_index in tokens[1:] if token_index not in self.post_process_remove_list
        ]
        output_string = "".join(self.sp_model.id_to_piece(filtered_hypo_tokens)).replace("\u2581", " ")

        if lstrip:
            return output_string.lstrip()
        else:
            return output_string
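

# Sketch of decoding (assumes a trained SentencePiece model at a hypothetical local
# path; the token IDs would come from an RNN-T beam-search hypothesis):
#
#   >>> processor = _SentencePieceTokenProcessor("spm_bpe_4096_librispeech.model")
#   >>> transcript = processor(hypothesis_tokens)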


@dataclass
class RNNTBundle:
    """Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
    inference with an RNN-T model.

    More specifically, the class provides methods that produce the featurization pipeline,
    decoder wrapping the specified RNN-T model, and output token post-processor that together
    constitute a complete end-to-end ASR inference pipeline that produces a text sequence
    given a raw waveform.

    It can support non-streaming (full-context) inference as well as streaming inference.

    Users should not directly instantiate objects of this class; rather, users should use the
    instances (representing pre-trained models) that exist within the module,
    e.g. :data:`torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH`.

    Example
        >>> import torchaudio
        >>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
        >>> import torch
        >>>
        >>> # Non-streaming inference.
        >>> # Build feature extractor, decoder with RNN-T model, and token processor.
        >>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
        100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
        >>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
        Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
        100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
        >>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
        100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
        >>>
        >>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
        >>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
        >>> waveform = next(iter(dataset))[0].squeeze()
        >>>
        >>> with torch.no_grad():
        >>>     # Produce mel-scale spectrogram features.
        >>>     features, length = feature_extractor(waveform)
        >>>
        >>>     # Generate top-10 hypotheses.
        >>>     hypotheses = decoder(features, length, 10)
        >>>
        >>> # For top hypothesis, convert predicted tokens to text.
        >>> text = token_processor(hypotheses[0][0])
        >>> print(text)
        he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
        >>>
        >>>
        >>> # Streaming inference.
        >>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
        >>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
        >>> num_samples_segment_right_context = (
        >>>     num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
        >>> )
        >>>
        >>> # Build streaming inference feature extractor.
        >>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
        >>>
        >>> # Process same waveform as before, this time sequentially across overlapping segments
        >>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
        >>> state, hypothesis = None, None
        >>> for idx in range(0, len(waveform), num_samples_segment):
        >>>     segment = waveform[idx: idx + num_samples_segment_right_context]
        >>>     segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
        >>>     with torch.no_grad():
        >>>         features, length = streaming_feature_extractor(segment)
        >>>         hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        >>>     hypothesis = hypotheses[0]
        >>>     transcript = token_processor(hypothesis[0])
        >>>     if transcript:
        >>>         print(transcript, end=" ", flush=True)
        he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
    """

    class FeatureExtractor(_FeatureExtractor):
        """Interface of the feature extraction part of RNN-T pipeline"""

    class TokenProcessor(_TokenProcessor):
        """Interface of the token processor part of RNN-T pipeline"""

    _rnnt_path: str
    _rnnt_factory_func: Callable[[], RNNT]
    _global_stats_path: str
    _sp_model_path: str
    _right_padding: int
    _blank: int
    _sample_rate: int
    _n_fft: int
    _n_mels: int
    _hop_length: int
    _segment_length: int
    _right_context_length: int

    def _get_model(self) -> RNNT:
        model = self._rnnt_factory_func()
        path = torchaudio.utils.download_asset(self._rnnt_path)
        state_dict = torch.load(path)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @property
    def sample_rate(self) -> int:
        """Sample rate (in cycles per second) of input waveforms.

        :type: int
        """
        return self._sample_rate

    @property
    def n_fft(self) -> int:
        """Size of FFT window to use.

        :type: int
        """
        return self._n_fft

    @property
    def n_mels(self) -> int:
        """Number of mel spectrogram features to extract from input waveforms.

        :type: int
        """
        return self._n_mels

    @property
    def hop_length(self) -> int:
        """Number of samples between successive frames in input expected by model.

        :type: int
        """
        return self._hop_length

    @property
    def segment_length(self) -> int:
        """Number of frames in segment in input expected by model.

        :type: int
        """
        return self._segment_length

    @property
    def right_context_length(self) -> int:
        """Number of frames in right contextual block in input expected by model.

        :type: int
        """
        return self._right_context_length

    def get_decoder(self) -> RNNTBeamSearch:
        """Constructs RNN-T decoder.

        Returns:
            RNNTBeamSearch
        """
        model = self._get_model()
        return RNNTBeamSearch(model, self._blank)
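
    # Note: the returned ``RNNTBeamSearch`` is used two ways in the class docstring
    # example above: called directly for full-context decoding, or via ``infer`` with
    # carried-over ``state``/``hypothesis`` for streaming decoding.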

    def get_feature_extractor(self) -> FeatureExtractor:
        """Constructs feature extractor for non-streaming (full-context) ASR.

        Returns:
            FeatureExtractor
        """
        local_path = torchaudio.utils.download_asset(self._global_stats_path)
        return _ModuleFeatureExtractor(
            torch.nn.Sequential(
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
                ),
                _FunctionalModule(lambda x: x.transpose(1, 0)),
                _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
                _GlobalStatsNormalization(local_path),
                # Zero-pad the end of the time dimension by ``_right_padding`` frames,
                # supplying right context for the final frames.
                _FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 0, 0, self._right_padding))),
            )
        )
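
    # The streaming extractor below is identical except that it omits the final
    # padding step: in streaming use, right context arrives with each subsequent
    # segment, so no trailing frames need to be appended.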

    def get_streaming_feature_extractor(self) -> FeatureExtractor:
        """Constructs feature extractor for streaming (simultaneous) ASR.

        Returns:
            FeatureExtractor
        """
        local_path = torchaudio.utils.download_asset(self._global_stats_path)
        return _ModuleFeatureExtractor(
            torch.nn.Sequential(
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
                ),
                _FunctionalModule(lambda x: x.transpose(1, 0)),
                _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
                _GlobalStatsNormalization(local_path),
            )
        )

    def get_token_processor(self) -> TokenProcessor:
        """Constructs token processor.

        Returns:
            TokenProcessor
        """
        local_path = torchaudio.utils.download_asset(self._sp_model_path)
        return _SentencePieceTokenProcessor(local_path)


EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
    _rnnt_path="models/emformer_rnnt_base_librispeech.pt",
    _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
    _global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
    _sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
    _right_padding=4,
    _blank=4096,
    _sample_rate=16000,
    _n_fft=400,
    _n_mels=80,
    _hop_length=160,
    _segment_length=16,
    _right_context_length=4,
)
EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """ASR pipeline based on Emformer-RNNT,
pretrained on *LibriSpeech* dataset :cite:`7178964`,
capable of performing both streaming and non-streaming inference.

The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with default arguments.

Please refer to :py:class:`RNNTBundle` for usage instructions.
"""