340 lines
12 KiB
Python
340 lines
12 KiB
Python
|
from typing import Callable, Dict, List, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
from torchaudio.models import RNNT
|
||
|
|
||
|
|
||
|
# Names exported as the public API of this module.
__all__ = ["Hypothesis", "RNNTBeamSearch"]


# A hypothesis is a plain 4-tuple; the ``_get_hypo_*`` accessors in this file read its
# slots by position:
#   [0] tokens, [1] prediction network output, [2] prediction network state, [3] score.
Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
# The documentation text is attached by assigning ``__doc__`` on the alias object.
Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
represented as tuple of (tokens, prediction network output, prediction network state, score).
"""
|
||
|
|
||
|
|
||
|
def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    """Return the token sequence held in slot 0 of ``hypo``.

    The list is returned as stored (no copy is made).
    """
    tokens: List[int] = hypo[0]
    return tokens
|
||
|
|
||
|
|
||
|
def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
    """Return the prediction network output held in slot 1 of ``hypo``."""
    predictor_out: torch.Tensor = hypo[1]
    return predictor_out
|
||
|
|
||
|
|
||
|
def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
    """Return the prediction network state held in slot 2 of ``hypo``.

    The nested list is returned as stored (no copy is made).
    """
    state: List[List[torch.Tensor]] = hypo[2]
    return state
|
||
|
|
||
|
|
||
|
def _get_hypo_score(hypo: Hypothesis) -> float:
    """Return the score held in slot 3 of ``hypo``."""
    score: float = hypo[3]
    return score
|
||
|
|
||
|
|
||
|
def _get_hypo_key(hypo: Hypothesis) -> str:
    """Return a string key identifying ``hypo`` by its token sequence.

    The key is the string form of the token list (slot 0), so two hypotheses
    with equal token sequences map to the same key regardless of their
    predictor output, state, or score.
    """
    tokens: List[int] = hypo[0]
    return str(tokens)
|
||
|
|
||
|
|
||
|
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
    """Concatenate the per-hypothesis predictor states into one batched state.

    The nesting of the first hypothesis' state is used as the template: for
    every position ``[i][j]`` in it, the tensors found at that position in each
    hypothesis are concatenated with ``torch.cat`` (default dimension 0), in
    the order of ``hypos``.

    Args:
        hypos (List[Hypothesis]): hypotheses whose states are to be batched.
            Must be non-empty, since the first element is indexed directly.

    Returns:
        List[List[torch.Tensor]]: state with the same nesting as a single
        hypothesis' state, each tensor being the concatenation across ``hypos``.
    """
    template = _get_hypo_state(hypos[0])
    batched: List[List[torch.Tensor]] = []
    for outer_idx, template_components in enumerate(template):
        merged_components: List[torch.Tensor] = []
        for inner_idx in range(len(template_components)):
            per_hypo = [_get_hypo_state(h)[outer_idx][inner_idx] for h in hypos]
            merged_components.append(torch.cat(per_hypo))
        batched.append(merged_components)
    return batched
|
||
|
|
||
|
|
||
|
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
    """Extract entry ``idx`` along dimension 0 from every tensor of a nested state.

    Args:
        states (List[List[torch.Tensor]]): nested state tensors to slice.
        idx (int): index along dimension 0 to select.
        device (torch.device): device on which the index tensor is created.

    Returns:
        List[List[torch.Tensor]]: state with the same nesting as ``states``;
        each tensor keeps dimension 0, now with size 1 (``index_select`` with a
        one-element index).
    """
    selector = torch.tensor([idx], device=device)
    sliced: List[List[torch.Tensor]] = []
    for components in states:
        sliced.append([component.index_select(0, selector) for component in components])
    return sliced
|
||
|
|
||
|
|
||
|
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
    """Return the hypothesis score normalized by token sequence length.

    The divisor is the number of tokens plus one, so it is never zero, even
    for an empty token list.
    """
    normalizer = len(_get_hypo_tokens(hypo)) + 1
    return _get_hypo_score(hypo) / normalizer
|
||
|
|
||
|
|
||
|
def _compute_updated_scores(
    hypos: List[Hypothesis],
    next_token_probs: torch.Tensor,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pick the ``beam_width`` best (hypothesis, token) extensions.

    Each hypothesis score is added to every column of ``next_token_probs``
    except the last one, and the top ``beam_width`` sums over the flattened
    result are selected. The last column is left out here; ``_gen_b_hypos``
    reads it separately as the blank score.

    Args:
        hypos (List[Hypothesis]): hypotheses being extended; row ``i`` of
            ``next_token_probs`` is paired with ``hypos[i]``.
        next_token_probs (torch.Tensor): per-hypothesis token scores, indexed
            as (hypothesis, token).
        beam_width (int): number of candidates to select. ``topk`` raises if
            this exceeds the number of available (hypothesis, token) pairs.

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor):
            selected scores, index into ``hypos`` for each selection, and
            token index for each selection.
    """
    base_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
    # One row per hypothesis, one column per token other than the last.
    candidate_scores = base_scores + next_token_probs[:, :-1]
    tokens_per_hypo = candidate_scores.shape[1]

    best_scores, flat_positions = torch.topk(candidate_scores.reshape(-1), beam_width)
    # Undo the flattening: row index -> hypothesis, column index -> token.
    best_hypo_idx = torch.div(flat_positions, tokens_per_hypo, rounding_mode="trunc")
    best_token = flat_positions % tokens_per_hypo
    return best_scores, best_hypo_idx, best_token
|
||
|
|
||
|
|
||
|
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
    """Delete from ``hypo_list`` the first entry whose key matches that of ``hypo``.

    The list is modified in place. At most one entry is removed; if no entry
    matches, the list is left unchanged.
    """
    wanted_key = _get_hypo_key(hypo)
    for position, candidate in enumerate(hypo_list):
        if _get_hypo_key(candidate) == wanted_key:
            del hypo_list[position]
            # Stop right away: only the first match is removed.
            break
|
||
|
|
||
|
|
||
|
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    See Also:
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.

    Note:
        ``blank`` is used as the token that seeds the initial hypothesis. When scoring,
        the blank score is read from the last column of the joint network output.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _check_and_batch_inputs(
        self, input: torch.Tensor, length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Validate ``input`` / ``length`` shapes and add the batch dimension if missing.

        Shared by :meth:`forward` and :meth:`infer` so both entry points apply
        exactly the same checks.

        Args:
            input (torch.Tensor): frames with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames, with shape () or (1,).

        Returns:
            (torch.Tensor, torch.Tensor): ``input`` with shape (1, T, D) and
            ``length`` with shape (1,).

        Raises:
            ValueError: if either tensor has an unsupported shape.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        return input, length

    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
        """Build the single seed hypothesis used when no prior hypotheses are given.

        The prediction network is run once on the blank token with no prior
        state; the seed hypothesis holds ``[blank]`` as tokens, the resulting
        predictor output and state, and a score of ``0.0``.
        """
        token = self.blank
        state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = (
            [token],
            pred_out[0].detach(),
            pred_state,
            0.0,
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Compute log-probabilities over next tokens for every hypothesis.

        The stored predictor outputs of ``hypos`` are stacked and joined with
        ``enc_out``; the joint output is divided by ``self.temperature`` and
        passed through ``log_softmax`` over the token dimension.

        Returns:
            torch.Tensor: one row of token log-probabilities per hypothesis.
        """
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Extend every hypothesis in ``a_hypos`` with blank and merge it into ``b_hypos``.

        For each hypothesis, the blank score (last column of its row in
        ``next_token_probs``) is added to its score. If ``key_to_b_hypo``
        already holds a hypothesis with the same token sequence, that entry is
        removed from ``b_hypos`` and the two scores are combined with
        ``logaddexp``.

        Note:
            ``b_hypos`` and ``key_to_b_hypo`` are modified in place.

        Returns:
            List[Hypothesis]: the updated ``b_hypos`` as a new list sorted by
            score in ascending order.
        """
        for i, h_a in enumerate(a_hypos):
            key = _get_hypo_key(h_a)
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            if key in key_to_b_hypo:
                h_b = key_to_b_hypo[key]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            # The new entry carries h_a's tokens, so its key is ``key``.
            key_to_b_hypo[key] = h_b
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Extend ``a_hypos`` with non-blank tokens that can still beat the current beam.

        The ``beam_width`` best non-blank extensions are computed, and only
        those scoring strictly above the ``beam_width``-th best score in
        ``b_hypos`` are kept. While ``b_hypos`` holds fewer than ``beam_width``
        entries the threshold is ``-inf``.

        Args:
            a_hypos (List[Hypothesis]): hypotheses to extend.
            b_hypos (List[Hypothesis]): blank-extended hypotheses, expected in
                ascending score order as returned by :meth:`_gen_b_hypos`.
            next_token_probs (torch.Tensor): token log-probabilities, one row per
                entry of ``a_hypos``.
            t (int): current time step; only forwarded to :meth:`_gen_new_hypos`.
            beam_width (int): beam size.
            device (torch.device): device for tensors fed to the prediction network.

        Returns:
            List[Hypothesis]: the retained extended hypotheses (possibly empty).
        """
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            # b_hypos is in ascending order, so this is the beam_width-th best score.
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Create hypotheses by appending ``tokens[i]`` to ``base_hypos[i]``.

        The prediction network is run once on all new tokens, using the batched
        states of ``base_hypos``; each new hypothesis receives its own slice of
        the resulting output and state, and ``scores[i]`` as its score.

        Note:
            ``t`` is accepted but not used by this method.
        """
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[List[Hypothesis]],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Run beam search over the encoder output, one time step at a time.

        Args:
            enc_out (torch.Tensor): encoder output; dimension 1 is iterated as time.
            hypo (List[Hypothesis] or None): hypotheses to seed the search with.
                If ``None``, a single blank-seeded hypothesis is used.
            beam_width (int): beam size.

        Returns:
            List[Hypothesis]: up to ``beam_width`` hypotheses ranked by
            ``self.hypo_sort_key``, best first. If ``enc_out`` has no time
            steps, the seed hypotheses are returned unchanged.
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # Fewer than beam_width hypotheses may exist here (e.g. when the
            # step_max_tokens limit ends the expansion early); topk raises if
            # asked for more elements than are available, so clamp k.
            n_keep = min(beam_width, len(b_hypos))
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(n_keep)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.

        Raises:
            ValueError: if ``input`` or ``length`` has an unsupported shape.
        """
        input, length = self._check_and_batch_inputs(input, length)

        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[List[Hypothesis]] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.

        Raises:
            ValueError: if ``input`` or ``length`` has an unsupported shape.
        """
        input, length = self._check_and_batch_inputs(input, length)

        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
|