Traktor/myenv/Lib/site-packages/torchaudio/_backend/ffmpeg.py

335 lines
11 KiB
Python
Raw Permalink Normal View History

2024-05-26 05:12:46 +02:00
import os
import re
import sys
from typing import BinaryIO, Optional, Tuple, Union
import torch
import torchaudio
from .backend import Backend
from .common import AudioMetaData
InputType = Union[BinaryIO, str, os.PathLike]
def info_audio(
src: InputType,
format: Optional[str],
buffer_size: int = 4096,
) -> AudioMetaData:
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sinfo = s.get_src_stream_info(s.default_audio_stream)
if sinfo.num_frames == 0:
waveform = _load_audio(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo.num_frames
return AudioMetaData(
int(sinfo.sample_rate),
num_frames,
sinfo.num_channels,
sinfo.bits_per_sample,
sinfo.codec.upper(),
)
def _get_load_filter(
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
) -> Optional[str]:
if frame_offset < 0:
raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
if num_frames == 0 or num_frames < -1:
raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))
# All default values -> no filter
if frame_offset == 0 and num_frames == -1 and not convert:
return None
# Only convert
aformat = "aformat=sample_fmts=fltp"
if frame_offset == 0 and num_frames == -1 and convert:
return aformat
# At least one of frame_offset or num_frames has non-default value
if num_frames > 0:
atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
else:
atrim = "atrim=start_sample={}".format(frame_offset)
if not convert:
return atrim
return "{},{}".format(atrim, aformat)
def _load_audio(
s: "torchaudio.io.StreamReader",
filter: Optional[str] = None,
channels_first: bool = True,
) -> torch.Tensor:
s.add_audio_stream(-1, -1, filter_desc=filter)
s.process_all_packets()
chunk = s.pop_chunks()[0]
if chunk is None:
raise RuntimeError("Failed to decode audio.")
waveform = chunk._elem
return waveform.T if channels_first else waveform
def load_audio(
src: InputType,
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
if hasattr(src, "read") and format == "vorbis":
format = "ogg"
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.default_audio_stream).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio(s, filter, channels_first)
return waveform, sample_rate
def _get_sample_format(dtype: torch.dtype) -> str:
dtype_to_format = {
torch.uint8: "u8",
torch.int16: "s16",
torch.int32: "s32",
torch.int64: "s64",
torch.float32: "flt",
torch.float64: "dbl",
}
format = dtype_to_format.get(dtype)
if format is None:
raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
return format
def _native_endianness() -> str:
if sys.byteorder == "little":
return "le"
else:
return "be"
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
endianness = _native_endianness()
if not encoding:
if not bits_per_sample:
# default to PCM S16
return f"pcm_s16{endianness}"
if bits_per_sample == 8:
return "pcm_u8"
return f"pcm_s{bits_per_sample}{endianness}"
if encoding == "PCM_S":
if not bits_per_sample:
bits_per_sample = 16
if bits_per_sample == 8:
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
return f"pcm_s{bits_per_sample}{endianness}"
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "pcm_u8"
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
if encoding == "PCM_F":
if not bits_per_sample:
bits_per_sample = 32
if bits_per_sample in (32, 64):
return f"pcm_f{bits_per_sample}{endianness}"
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "pcm_mulaw"
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "pcm_alaw"
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
raise ValueError(f"WAV encoding {encoding} is not supported.")
def _get_flac_sample_fmt(bps):
if bps is None or bps == 16:
return "s16"
if bps == 24:
return "s32"
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
def _parse_save_args(
ext: Optional[str],
format: Optional[str],
encoding: Optional[str],
bps: Optional[int],
):
# torchaudio's save function accepts the followings, which do not 1to1 map
# to FFmpeg.
#
# - format: audio format
# - bits_per_sample: encoder sample format
# - encoding: such as PCM_U8.
#
# In FFmpeg, format is specified with the following three (and more)
#
# - muxer: could be audio format or container format.
# the one we passed to the constructor of StreamWriter
# - encoder: the audio encoder used to encode audio
# - encoder sample format: the format used by encoder to encode audio.
#
# If encoder sample format is different from source sample format, StreamWriter
# will insert a filter automatically.
#
def _type(spec):
# either format is exactly the specified one
# or extension matches to the spec AND there is no format override.
return format == spec or (format is None and ext == spec)
if _type("wav") or _type("amb"):
# wav is special because it supports different encoding through encoders
# each encoder only supports one encoder format
#
# amb format is a special case originated from libsox.
# It is basically a WAV format, with slight modification.
# https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
# It is a format so that decoders will recognize it as ambisonic.
# https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
# FFmpeg does not recognize amb because it is basically a WAV format.
muxer = "wav"
encoder = _get_encoder_for_wav(encoding, bps)
sample_fmt = None
elif _type("vorbis"):
# FFpmeg does not recognize vorbis extension, while libsox used to do.
# For the sake of bakward compatibility, (and the simplicity),
# we support the case where users want to do save("foo.vorbis")
muxer = "ogg"
encoder = "vorbis"
sample_fmt = None
else:
muxer = format
encoder = None
sample_fmt = None
if _type("flac"):
sample_fmt = _get_flac_sample_fmt(bps)
if _type("ogg"):
sample_fmt = _get_flac_sample_fmt(bps)
return muxer, encoder, sample_fmt
def save_audio(
uri: InputType,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
ext = None
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
if tokens := str(uri).split(".")[1:]:
ext = tokens[-1].lower()
muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
if channels_first:
src = src.T
s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
def _map_encoding(encoding: str) -> str:
for dst in ["PCM_S", "PCM_U", "PCM_F"]:
if dst in encoding:
return dst
if encoding == "PCM_MULAW":
return "ULAW"
elif encoding == "PCM_ALAW":
return "ALAW"
return encoding
def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> str:
if m := re.search(r"PCM_\w(\d+)\w*", encoding):
return int(m.group(1))
elif encoding in ["PCM_ALAW", "PCM_MULAW"]:
return 8
return bits_per_sample
class FFmpegBackend(Backend):
@staticmethod
def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
metadata = info_audio(uri, format, buffer_size)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
@staticmethod
def load(
uri: InputType,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
@staticmethod
def save(
uri: InputType,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
sample_rate,
channels_first,
format,
encoding,
bits_per_sample,
buffer_size,
compression,
)
@staticmethod
def can_decode(uri: InputType, format: Optional[str]) -> bool:
return True
@staticmethod
def can_encode(uri: InputType, format: Optional[str]) -> bool:
return True