import os
import re
import sys
from typing import BinaryIO, Optional, Tuple, Union

import torch
import torchaudio

from .backend import Backend
from .common import AudioMetaData

InputType = Union[BinaryIO, str, os.PathLike]


def info_audio(
    src: InputType,
    format: Optional[str],
    buffer_size: int = 4096,
) -> AudioMetaData:
    s = torchaudio.io.StreamReader(src, format, None, buffer_size)
    sinfo = s.get_src_stream_info(s.default_audio_stream)
    if sinfo.num_frames == 0:
        waveform = _load_audio(s)
        num_frames = waveform.size(1)
    else:
        num_frames = sinfo.num_frames
    return AudioMetaData(
        int(sinfo.sample_rate),
        num_frames,
        sinfo.num_channels,
        sinfo.bits_per_sample,
        sinfo.codec.upper(),
    )
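# A minimal usage sketch of info_audio (illustrative only, not executed here;
# "foo.wav" is a hypothetical path and the values are examples):
#
#   meta = info_audio("foo.wav", format=None)
#   meta.sample_rate, meta.num_frames, meta.num_channels   # e.g. (16000, 16000, 1)
#   meta.encoding                                           # raw codec name, e.g. "PCM_S16LE"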
def _get_load_filter(
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
) -> Optional[str]:
    if frame_offset < 0:
        raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
    if num_frames == 0 or num_frames < -1:
        raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))

    # No trimming and no conversion -> no filter
    if frame_offset == 0 and num_frames == -1 and not convert:
        return None
    # Conversion only
    aformat = "aformat=sample_fmts=fltp"
    if frame_offset == 0 and num_frames == -1 and convert:
        return aformat
    # At least one of frame_offset or num_frames has a non-default value
    if num_frames > 0:
        atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
    else:
        atrim = "atrim=start_sample={}".format(frame_offset)
    if not convert:
        return atrim
    return "{},{}".format(atrim, aformat)
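# Sketch of the filter strings _get_load_filter produces (illustrative values):
#
#   _get_load_filter()                        -> "aformat=sample_fmts=fltp"
#   _get_load_filter(convert=False)           -> None
#   _get_load_filter(frame_offset=100, num_frames=50, convert=False)
#       -> "atrim=start_sample=100:end_sample=150"
#   _get_load_filter(frame_offset=100, num_frames=50)
#       -> "atrim=start_sample=100:end_sample=150,aformat=sample_fmts=fltp"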
def _load_audio(
    s: "torchaudio.io.StreamReader",
    filter: Optional[str] = None,
    channels_first: bool = True,
) -> torch.Tensor:
    s.add_audio_stream(-1, -1, filter_desc=filter)
    s.process_all_packets()
    chunk = s.pop_chunks()[0]
    if chunk is None:
        raise RuntimeError("Failed to decode audio.")
    waveform = chunk._elem
    return waveform.T if channels_first else waveform
def load_audio(
    src: InputType,
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
    buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
    if hasattr(src, "read") and format == "vorbis":
        format = "ogg"
    s = torchaudio.io.StreamReader(src, format, None, buffer_size)
    sample_rate = int(s.get_src_stream_info(s.default_audio_stream).sample_rate)
    filter = _get_load_filter(frame_offset, num_frames, convert)
    waveform = _load_audio(s, filter, channels_first)
    return waveform, sample_rate
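# A minimal usage sketch of load_audio (illustrative only; the path is hypothetical):
#
#   waveform, sample_rate = load_audio("foo.wav", frame_offset=8000, num_frames=16000)
#   waveform.shape   # (num_channels, 16000) since channels_first defaults to True
#   waveform.dtype   # torch.float32 when convert=True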
def _get_sample_format(dtype: torch.dtype) -> str:
    dtype_to_format = {
        torch.uint8: "u8",
        torch.int16: "s16",
        torch.int32: "s32",
        torch.int64: "s64",
        torch.float32: "flt",
        torch.float64: "dbl",
    }
    format = dtype_to_format.get(dtype)
    if format is None:
        raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
    return format
def _native_endianness() -> str:
    if sys.byteorder == "little":
        return "le"
    else:
        return "be"
def _get_encoder_for_wav(encoding: Optional[str], bits_per_sample: Optional[int]) -> str:
    if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
        raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
    endianness = _native_endianness()
    if not encoding:
        if not bits_per_sample:
            # default to PCM S16
            return f"pcm_s16{endianness}"
        if bits_per_sample == 8:
            return "pcm_u8"
        return f"pcm_s{bits_per_sample}{endianness}"
    if encoding == "PCM_S":
        if not bits_per_sample:
            bits_per_sample = 16
        if bits_per_sample == 8:
            raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
        return f"pcm_s{bits_per_sample}{endianness}"
    if encoding == "PCM_U":
        if bits_per_sample in (None, 8):
            return "pcm_u8"
        raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
    if encoding == "PCM_F":
        if not bits_per_sample:
            bits_per_sample = 32
        if bits_per_sample in (32, 64):
            return f"pcm_f{bits_per_sample}{endianness}"
        raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "pcm_mulaw"
        raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
    if encoding == "ALAW":
        if bits_per_sample in (None, 8):
            return "pcm_alaw"
        raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
    raise ValueError(f"WAV encoding {encoding} is not supported.")
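# Sketch of the encoding/bits_per_sample -> FFmpeg encoder mapping above,
# assuming a little-endian host:
#
#   _get_encoder_for_wav(None, None)      -> "pcm_s16le"
#   _get_encoder_for_wav(None, 8)         -> "pcm_u8"
#   _get_encoder_for_wav("PCM_F", None)   -> "pcm_f32le"
#   _get_encoder_for_wav("ULAW", 8)       -> "pcm_mulaw"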
def _get_flac_sample_fmt(bps: Optional[int]) -> str:
    if bps is None or bps == 16:
        return "s16"
    if bps == 24:
        return "s32"
    raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
def _parse_save_args(
    ext: Optional[str],
    format: Optional[str],
    encoding: Optional[str],
    bps: Optional[int],
):
    # torchaudio's save function accepts the following arguments, which do not map
    # one-to-one to FFmpeg concepts:
    #
    # - format: audio format
    # - bits_per_sample: encoder sample format
    # - encoding: such as PCM_U8.
    #
    # In FFmpeg, the output format is specified with (at least) the following three:
    #
    # - muxer: the audio or container format,
    #   i.e. the one passed to the constructor of StreamWriter
    # - encoder: the audio encoder used to encode the audio
    # - encoder sample format: the sample format the encoder uses to encode the audio.
    #
    # If the encoder sample format differs from the source sample format, StreamWriter
    # inserts a conversion filter automatically.
    #
    def _type(spec):
        # Either the format is exactly the specified one,
        # or the extension matches the spec AND there is no format override.
        return format == spec or (format is None and ext == spec)

    if _type("wav") or _type("amb"):
        # wav is special because it supports different encodings through encoders;
        # each encoder supports only one encoder sample format.
        #
        # The amb format is a special case originating from libsox.
        # It is basically the WAV format with a slight modification.
        # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
        # It exists so that decoders recognize the file as ambisonic.
        # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
        # FFmpeg does not treat amb as a distinct format, since it is basically WAV.
        muxer = "wav"
        encoder = _get_encoder_for_wav(encoding, bps)
        sample_fmt = None
    elif _type("vorbis"):
        # FFmpeg does not recognize the vorbis extension, while libsox used to.
        # For the sake of backward compatibility (and simplicity),
        # we support the case where users want to do save("foo.vorbis").
        muxer = "ogg"
        encoder = "vorbis"
        sample_fmt = None
    else:
        muxer = format
        encoder = None
        sample_fmt = None
        if _type("flac"):
            sample_fmt = _get_flac_sample_fmt(bps)
        if _type("ogg"):
            sample_fmt = _get_flac_sample_fmt(bps)
    return muxer, encoder, sample_fmt
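# Sketch of the (muxer, encoder, sample_fmt) triples returned above,
# assuming a little-endian host:
#
#   _parse_save_args("wav", None, None, None)     -> ("wav", "pcm_s16le", None)
#   _parse_save_args("wav", None, "ULAW", 8)      -> ("wav", "pcm_mulaw", None)
#   _parse_save_args("vorbis", None, None, None)  -> ("ogg", "vorbis", None)
#   _parse_save_args("flac", "flac", None, 24)    -> ("flac", None, "s32")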
def save_audio(
    uri: InputType,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
    buffer_size: int = 4096,
    compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
    ext = None
    if hasattr(uri, "write"):
        if format is None:
            raise RuntimeError("'format' is required when saving to file object.")
    else:
        uri = os.path.normpath(uri)
        if tokens := str(uri).split(".")[1:]:
            ext = tokens[-1].lower()

    muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)

    if channels_first:
        src = src.T

    s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
    s.add_audio_stream(
        sample_rate,
        num_channels=src.size(-1),
        format=_get_sample_format(src.dtype),
        encoder=encoder,
        encoder_format=enc_fmt,
        codec_config=compression,
    )
    with s.open():
        s.write_audio_chunk(0, src)
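# A minimal usage sketch of save_audio (illustrative only; the path and tensor
# are hypothetical):
#
#   waveform = torch.zeros(1, 16000, dtype=torch.float32)  # (channel, time)
#   save_audio("out.wav", waveform, 16000, encoding="PCM_S", bits_per_sample=16)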
def _map_encoding(encoding: str) -> str:
    for dst in ["PCM_S", "PCM_U", "PCM_F"]:
        if dst in encoding:
            return dst
    if encoding == "PCM_MULAW":
        return "ULAW"
    elif encoding == "PCM_ALAW":
        return "ALAW"
    return encoding


def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> int:
    if m := re.search(r"PCM_\w(\d+)\w*", encoding):
        return int(m.group(1))
    elif encoding in ["PCM_ALAW", "PCM_MULAW"]:
        return 8
    return bits_per_sample
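# Sketch of how raw FFmpeg codec names are normalized by the two helpers above:
#
#   _map_encoding("PCM_S16LE")             -> "PCM_S"
#   _get_bits_per_sample("PCM_S16LE", 0)   -> 16
#   _map_encoding("PCM_MULAW")             -> "ULAW"
#   _get_bits_per_sample("PCM_MULAW", 0)   -> 8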
class FFmpegBackend(Backend):
    @staticmethod
    def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
        metadata = info_audio(uri, format, buffer_size)
        metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
        metadata.encoding = _map_encoding(metadata.encoding)
        return metadata
    @staticmethod
    def load(
        uri: InputType,
        frame_offset: int = 0,
        num_frames: int = -1,
        normalize: bool = True,
        channels_first: bool = True,
        format: Optional[str] = None,
        buffer_size: int = 4096,
    ) -> Tuple[torch.Tensor, int]:
        return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)
    @staticmethod
    def save(
        uri: InputType,
        src: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        format: Optional[str] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
        buffer_size: int = 4096,
        compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
    ) -> None:
        if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
            raise ValueError(
                "FFmpeg backend expects non-`None` value for argument `compression` to be of "
                f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}."
            )
        save_audio(
            uri,
            src,
            sample_rate,
            channels_first,
            format,
            encoding,
            bits_per_sample,
            buffer_size,
            compression,
        )
    @staticmethod
    def can_decode(uri: InputType, format: Optional[str]) -> bool:
        return True

    @staticmethod
    def can_encode(uri: InputType, format: Optional[str]) -> bool:
        return True
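# A minimal usage sketch of FFmpegBackend (illustrative only; paths are hypothetical,
# and in typical use this backend is selected through torchaudio's backend dispatch
# rather than called directly):
#
#   waveform, sr = FFmpegBackend.load("foo.wav")
#   FFmpegBackend.save("copy.wav", waveform, sr)
#   FFmpegBackend.info("foo.wav", format=None)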