mirror of https://github.com/rhasspy/piper
Initial check in of Python training code
parent 344b483904
commit a6b2d2e69c
@@ -1 +1,3 @@
/build/
/src/python/.venv/
/local/
@@ -0,0 +1,13 @@
#!/usr/bin/env bash
set -eo pipefail

this_dir="$( cd "$( dirname "$0" )" && pwd )"

if [ -d "${this_dir}/.venv" ]; then
    source "${this_dir}/.venv/bin/activate"
fi

cd "${this_dir}/larynx_train/vits/monotonic_align"
mkdir -p monotonic_align
cythonize -i core.pyx
mv core*.so monotonic_align/
@@ -0,0 +1,11 @@
.DS_Store
.idea
*.log
tmp/

*.py[cod]
*.egg
build
htmlcov

.venv/
@@ -0,0 +1,6 @@
[settings]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
@@ -0,0 +1 @@
0.1.0
@@ -0,0 +1,61 @@
import argparse
import json
import logging
from pathlib import Path

import torch
from pytorch_lightning import Trainer

from .vits.lightning import VitsModel

_LOGGER = logging.getLogger(__package__)


def main():
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
    )
    Trainer.add_argparse_args(parser)
    VitsModel.add_model_specific_args(parser)
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()
    _LOGGER.debug(args)

    args.dataset_dir = Path(args.dataset_dir)
    if not args.default_root_dir:
        args.default_root_dir = args.dataset_dir

    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)

    config_path = args.dataset_dir / "config.json"
    dataset_path = args.dataset_dir / "dataset.jsonl"

    with open(config_path, "r", encoding="utf-8") as config_file:
        # See preprocess.py for format
        config = json.load(config_file)
        num_symbols = int(config["num_symbols"])
        num_speakers = int(config["num_speakers"])
        sample_rate = int(config["audio"]["sample_rate"])

    trainer = Trainer.from_argparse_args(args)
    dict_args = vars(args)
    model = VitsModel(
        num_symbols=num_symbols,
        num_speakers=num_speakers,
        sample_rate=sample_rate,
        dataset=[dataset_path],
        **dict_args
    )

    trainer.fit(model)


# -----------------------------------------------------------------------------


if __name__ == "__main__":
    main()
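For reference, main() above only requires that config.json in the dataset directory provide the three fields it reads; a minimal illustrative example (values are placeholders, the full file is produced by preprocess.py later in this commit):

    {"audio": {"sample_rate": 22050}, "num_symbols": 130, "num_speakers": 1}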
@@ -0,0 +1,20 @@
"""Shared access to package resources"""
import json
import os
import typing
from pathlib import Path

try:
    import importlib.resources

    files = importlib.resources.files
except (ImportError, AttributeError):
    # Backport for Python < 3.9
    import importlib_resources  # type: ignore

    files = importlib_resources.files

_PACKAGE = "larynx_train"
_DIR = Path(typing.cast(os.PathLike, files(_PACKAGE)))

__version__ = (_DIR / "VERSION").read_text(encoding="utf-8").strip()
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from typing import Optional

import torch

from .vits.lightning import VitsModel

_LOGGER = logging.getLogger("larynx_train.export_onnx")

OPSET_VERSION = 15


def main():
    """Main entry point"""
    torch.manual_seed(12345)

    parser = argparse.ArgumentParser(prog="larynx_train.export_onnx")
    parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)")
    parser.add_argument("output", help="Path to output model (.onnx)")

    parser.add_argument(
        "--debug", action="store_true", help="Print DEBUG messages to the console"
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    _LOGGER.debug(args)

    # -------------------------------------------------------------------------

    args.checkpoint = Path(args.checkpoint)
    args.output = Path(args.output)
    args.output.parent.mkdir(parents=True, exist_ok=True)

    model = VitsModel.load_from_checkpoint(args.checkpoint)
    model_g = model.model_g

    num_symbols = model_g.n_vocab
    num_speakers = model_g.n_speakers

    # Inference only
    model_g.eval()

    with torch.no_grad():
        model_g.dec.remove_weight_norm()

    # old_forward = model_g.infer

    def infer_forward(text, text_lengths, scales, sid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0].unsqueeze(1)

        return audio

    model_g.forward = infer_forward

    sequences = torch.randint(low=0, high=num_symbols, size=(1, 50), dtype=torch.long)
    sequence_lengths = torch.LongTensor([sequences.size(1)])

    sid: Optional[torch.LongTensor] = None
    if num_speakers > 1:
        sid = torch.LongTensor([0])

    # noise, length, noise_w (matches the unpacking order in infer_forward)
    scales = torch.FloatTensor([0.667, 1.0, 0.8])

    dummy_input = (sequences, sequence_lengths, scales, sid)

    # Export
    torch.onnx.export(
        model=model_g,
        args=dummy_input,
        f=str(args.output),
        verbose=True,
        opset_version=OPSET_VERSION,
        input_names=["input", "input_lengths", "scales", "sid"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phonemes"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "time"},
        },
    )

    _LOGGER.info("Exported model to %s", args.output)


# -----------------------------------------------------------------------------

if __name__ == "__main__":
    main()
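A quick way to sanity-check the exported graph, as a hedged sketch (assumes the onnx package is installed; it is not imported by this script, and the path is illustrative):

    import onnx

    onnx_model = onnx.load("model.onnx")  # path is illustrative
    onnx.checker.check_model(onnx_model)  # raises if the graph is malformed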
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import sys
import time
from pathlib import Path

import torch

from .vits.lightning import VitsModel
from .vits.utils import audio_float_to_int16
from .vits.wavfile import write as write_wav

_LOGGER = logging.getLogger("larynx_train.infer")


def main():
    """Main entry point"""
    logging.basicConfig(level=logging.DEBUG)
    parser = argparse.ArgumentParser(prog="larynx_train.infer")
    parser.add_argument(
        "--checkpoint", required=True, help="Path to model checkpoint (.ckpt)"
    )
    parser.add_argument("--output-dir", required=True, help="Path to write WAV files")
    parser.add_argument("--sample-rate", type=int, default=22050)
    args = parser.parse_args()

    args.output_dir = Path(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    model = VitsModel.load_from_checkpoint(args.checkpoint)

    # Inference only
    model.eval()

    with torch.no_grad():
        model.model_g.dec.remove_weight_norm()

    for i, line in enumerate(sys.stdin):
        line = line.strip()
        if not line:
            continue

        utt = json.loads(line)
        # utt_id = utt["id"]
        utt_id = str(i)
        phoneme_ids = utt["phoneme_ids"]

        text = torch.LongTensor(phoneme_ids).unsqueeze(0)
        text_lengths = torch.LongTensor([len(phoneme_ids)])
        scales = [0.667, 1.0, 0.8]

        start_time = time.perf_counter()
        audio = model(text, text_lengths, scales).detach().numpy()
        audio = audio_float_to_int16(audio)
        end_time = time.perf_counter()

        audio_duration_sec = audio.shape[-1] / args.sample_rate
        infer_sec = end_time - start_time
        real_time_factor = (
            infer_sec / audio_duration_sec if audio_duration_sec > 0 else 0.0
        )

        _LOGGER.debug(
            "Real-time factor for %s: %0.2f (infer=%0.2f sec, audio=%0.2f sec)",
            i + 1,
            real_time_factor,
            infer_sec,
            audio_duration_sec,
        )

        output_path = args.output_dir / f"{utt_id}.wav"
        write_wav(str(output_path), args.sample_rate, audio)


if __name__ == "__main__":
    main()
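Each stdin line is one JSON object matching the records that preprocess.py writes to dataset.jsonl; a minimal hand-made example (the phoneme IDs are placeholders):

    {"phoneme_ids": [1, 20, 35, 42, 2]}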
@@ -0,0 +1,189 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import math
import sys
import time
from pathlib import Path

import numpy as np
import onnxruntime

from .vits.utils import audio_float_to_int16
from .vits.wavfile import write as write_wav

_LOGGER = logging.getLogger("larynx_train.infer_onnx")


def main():
    """Main entry point"""
    logging.basicConfig(level=logging.DEBUG)
    parser = argparse.ArgumentParser(prog="larynx_train.infer_onnx")
    parser.add_argument("--model", required=True, help="Path to model (.onnx)")
    parser.add_argument("--output-dir", required=True, help="Path to write WAV files")
    parser.add_argument("--sample-rate", type=int, default=22050)
    parser.add_argument("--noise-scale", type=float, default=0.667)
    parser.add_argument("--noise-scale-w", type=float, default=0.8)
    parser.add_argument("--length-scale", type=float, default=1.0)
    args = parser.parse_args()

    args.output_dir = Path(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    sess_options = onnxruntime.SessionOptions()
    _LOGGER.debug("Loading model from %s", args.model)
    model = onnxruntime.InferenceSession(str(args.model), sess_options=sess_options)
    _LOGGER.info("Loaded model from %s", args.model)

    # Run the model once on "empty" text to capture the bias spectrogram used
    # by the denoiser below.
    text_empty = np.zeros((1, 300), dtype=np.int64)
    text_lengths_empty = np.array([text_empty.shape[1]], dtype=np.int64)
    scales = np.array(
        [args.noise_scale, args.length_scale, args.noise_scale_w],
        dtype=np.float32,
    )
    bias_audio = model.run(
        None,
        {"input": text_empty, "input_lengths": text_lengths_empty, "scales": scales},
    )[0].squeeze((0, 1))
    bias_spec, _ = transform(bias_audio)

    for i, line in enumerate(sys.stdin):
        line = line.strip()
        if not line:
            continue

        utt = json.loads(line)
        # utt_id = utt["id"]
        utt_id = str(i)
        phoneme_ids = utt["phoneme_ids"]

        text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)
        text_lengths = np.array([text.shape[1]], dtype=np.int64)
        scales = np.array(
            [args.noise_scale, args.length_scale, args.noise_scale_w],
            dtype=np.float32,
        )

        start_time = time.perf_counter()
        audio = model.run(
            None, {"input": text, "input_lengths": text_lengths, "scales": scales}
        )[0].squeeze((0, 1))
        audio = denoise(audio, bias_spec, 10)
        audio = audio_float_to_int16(audio.squeeze())
        end_time = time.perf_counter()

        audio_duration_sec = audio.shape[-1] / args.sample_rate
        infer_sec = end_time - start_time
        real_time_factor = (
            infer_sec / audio_duration_sec if audio_duration_sec > 0 else 0.0
        )

        _LOGGER.debug(
            "Real-time factor for %s: %0.2f (infer=%0.2f sec, audio=%0.2f sec)",
            i + 1,
            real_time_factor,
            infer_sec,
            audio_duration_sec,
        )

        output_path = args.output_dir / f"{utt_id}.wav"
        write_wav(str(output_path), args.sample_rate, audio)


def denoise(
    audio: np.ndarray, bias_spec: np.ndarray, denoiser_strength: float
) -> np.ndarray:
    audio_spec, audio_angles = transform(audio)

    a = bias_spec.shape[-1]
    b = audio_spec.shape[-1]
    repeats = max(1, math.ceil(b / a))
    bias_spec_repeat = np.repeat(bias_spec, repeats, axis=-1)[..., :b]

    audio_spec_denoised = audio_spec - (bias_spec_repeat * denoiser_strength)
    audio_spec_denoised = np.clip(audio_spec_denoised, a_min=0.0, a_max=None)
    audio_denoised = inverse(audio_spec_denoised, audio_angles)

    return audio_denoised


def stft(x, fft_size, hopsamp):
    """Compute and return the STFT of the supplied time domain signal x.

    Args:
        x (1-dim Numpy array): A time domain signal.
        fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used.
        hopsamp (int): The hop size, in samples.

    Returns:
        The STFT. The rows are the time slices and columns are the frequency bins.
    """
    window = np.hanning(fft_size)
    fft_size = int(fft_size)
    hopsamp = int(hopsamp)
    return np.array(
        [
            np.fft.rfft(window * x[i : i + fft_size])
            for i in range(0, len(x) - fft_size, hopsamp)
        ]
    )


def istft(X, fft_size, hopsamp):
    """Invert a STFT into a time domain signal.

    Args:
        X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins.
        fft_size (int): FFT size, in samples.
        hopsamp (int): The hop size, in samples.

    Returns:
        The inverse STFT.
    """
    fft_size = int(fft_size)
    hopsamp = int(hopsamp)
    window = np.hanning(fft_size)
    time_slices = X.shape[0]
    len_samples = int(time_slices * hopsamp + fft_size)
    x = np.zeros(len_samples)
    for n, i in enumerate(range(0, len(x) - fft_size, hopsamp)):
        x[i : i + fft_size] += window * np.real(np.fft.irfft(X[n]))
    return x


def inverse(magnitude, phase):
    recombine_magnitude_phase = np.concatenate(
        [magnitude * np.cos(phase), magnitude * np.sin(phase)], axis=1
    )

    x_org = recombine_magnitude_phase
    n_b, n_f, n_t = x_org.shape  # pylint: disable=unpacking-non-sequence
    x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64)
    x.real = x_org[:, : n_f // 2]
    x.imag = x_org[:, n_f // 2 :]
    inverse_transform = []
    for y in x:
        y_ = istft(y.T, fft_size=1024, hopsamp=256)
        inverse_transform.append(y_[None, :])

    inverse_transform = np.concatenate(inverse_transform, 0)

    return inverse_transform


def transform(input_data):
    x = input_data
    real_part = []
    imag_part = []
    for y in x:
        y_ = stft(y, fft_size=1024, hopsamp=256).T
        real_part.append(y_.real[None, :, :])  # pylint: disable=unsubscriptable-object
        imag_part.append(y_.imag[None, :, :])  # pylint: disable=unsubscriptable-object
    real_part = np.concatenate(real_part, 0)
    imag_part = np.concatenate(imag_part, 0)

    magnitude = np.sqrt(real_part**2 + imag_part**2)
    phase = np.arctan2(imag_part.data, real_part.data)

    return magnitude, phase


if __name__ == "__main__":
    main()
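A small sketch of the STFT helpers above (note that istft() applies the Hann window a second time during overlap-add, so it is not an exact inverse of stft()):

    import numpy as np

    x = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)  # 1 second at 16 kHz
    spec = stft(x, fft_size=1024, hopsamp=256)   # complex array: [time_slices, 513]
    y = istft(spec, fft_size=1024, hopsamp=256)  # windowed overlap-add back to samples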
@@ -0,0 +1,92 @@
from hashlib import sha256
from pathlib import Path
from typing import Optional, Tuple, Union

import librosa
import torch

from larynx_train.vits.mel_processing import spectrogram_torch

from .trim import trim_silence
from .vad import SileroVoiceActivityDetector

_DIR = Path(__file__).parent


def make_silence_detector() -> SileroVoiceActivityDetector:
    silence_model = _DIR / "models" / "silero_vad.onnx"
    return SileroVoiceActivityDetector(silence_model)


def cache_norm_audio(
    audio_path: Union[str, Path],
    cache_dir: Union[str, Path],
    detector: SileroVoiceActivityDetector,
    sample_rate: int,
    silence_threshold: float = 0.2,
    silence_samples_per_chunk: int = 480,
    silence_keep_chunks_before: int = 2,
    silence_keep_chunks_after: int = 2,
    filter_length: int = 1024,
    window_length: int = 1024,
    hop_length: int = 256,
    ignore_cache: bool = False,
) -> Tuple[Path, Path]:
    audio_path = Path(audio_path).absolute()
    cache_dir = Path(cache_dir)

    # Cache id is the SHA256 of the full audio path
    audio_cache_id = sha256(str(audio_path).encode()).hexdigest()

    audio_norm_path = cache_dir / f"{audio_cache_id}.pt"
    audio_spec_path = cache_dir / f"{audio_cache_id}.spec.pt"

    # Normalize audio
    audio_norm_tensor: Optional[torch.FloatTensor] = None
    if ignore_cache or (not audio_norm_path.exists()):
        # Trim silence first.
        #
        # The VAD model works on 16 kHz audio, so we determine the portion of
        # audio to keep and then just load that with librosa.
        vad_sample_rate = 16000
        audio_16khz, _sr = librosa.load(path=audio_path, sr=vad_sample_rate)

        offset_sec, duration_sec = trim_silence(
            audio_16khz,
            detector,
            threshold=silence_threshold,
            samples_per_chunk=silence_samples_per_chunk,
            sample_rate=vad_sample_rate,
            keep_chunks_before=silence_keep_chunks_before,
            keep_chunks_after=silence_keep_chunks_after,
        )

        # NOTE: audio is already in [-1, 1] coming from librosa
        audio_norm_array, _sr = librosa.load(
            path=audio_path,
            sr=sample_rate,
            offset=offset_sec,
            duration=duration_sec,
        )

        # Save to cache directory
        audio_norm_tensor = torch.FloatTensor(audio_norm_array).unsqueeze(0)
        torch.save(audio_norm_tensor, audio_norm_path)

    # Compute spectrogram
    if ignore_cache or (not audio_spec_path.exists()):
        if audio_norm_tensor is None:
            # Load pre-cached normalized audio
            audio_norm_tensor = torch.load(audio_norm_path)

        audio_spec_tensor = spectrogram_torch(
            y=audio_norm_tensor,
            n_fft=filter_length,
            sampling_rate=sample_rate,
            hop_size=hop_length,
            win_size=window_length,
            center=False,
        ).squeeze(0)
        torch.save(audio_spec_tensor, audio_spec_path)

    return audio_norm_path, audio_spec_path
Binary file not shown.
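Typical use of the cache, as a sketch (the path and sample rate here are illustrative):

    detector = make_silence_detector()
    norm_path, spec_path = cache_norm_audio(
        "wavs/utt_0001.wav",      # hypothetical input file
        cache_dir="cache/22050",
        detector=detector,
        sample_rate=22050,
    )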
@@ -0,0 +1,54 @@
from typing import Optional, Tuple

import numpy as np

from .vad import SileroVoiceActivityDetector


def trim_silence(
    audio_array: np.ndarray,
    detector: SileroVoiceActivityDetector,
    threshold: float = 0.2,
    samples_per_chunk: int = 480,
    sample_rate: int = 16000,
    keep_chunks_before: int = 2,
    keep_chunks_after: int = 2,
) -> Tuple[float, Optional[float]]:
    """Returns the offset/duration of trimmed audio in seconds"""
    offset_sec: float = 0.0
    duration_sec: Optional[float] = None
    first_chunk: Optional[int] = None
    last_chunk: Optional[int] = None
    seconds_per_chunk: float = samples_per_chunk / sample_rate

    chunk = audio_array[:samples_per_chunk]
    audio_array = audio_array[samples_per_chunk:]
    chunk_idx: int = 0

    # Determine main block of speech
    while len(audio_array) > 0:
        prob = detector(chunk, sample_rate=sample_rate)
        is_speech = prob >= threshold

        if is_speech:
            if first_chunk is None:
                # First speech
                first_chunk = chunk_idx
            else:
                # Last speech so far
                last_chunk = chunk_idx

        chunk = audio_array[:samples_per_chunk]
        audio_array = audio_array[samples_per_chunk:]
        chunk_idx += 1

    if (first_chunk is not None) and (last_chunk is not None):
        first_chunk = max(0, first_chunk - keep_chunks_before)
        last_chunk = min(chunk_idx, last_chunk + keep_chunks_after)

        # Compute offset/duration
        offset_sec = first_chunk * seconds_per_chunk
        last_sec = (last_chunk + 1) * seconds_per_chunk
        duration_sec = last_sec - offset_sec

    return offset_sec, duration_sec
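Used on its own, trim_silence() takes 16 kHz samples plus a detector and returns where the speech starts and how long it lasts; a sketch (audio_16khz loaded as in norm_audio above):

    offset_sec, duration_sec = trim_silence(
        audio_16khz, detector, threshold=0.2, sample_rate=16000
    )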
@@ -0,0 +1,54 @@
import typing
from pathlib import Path

import numpy as np
import onnxruntime


class SileroVoiceActivityDetector:
    """Detects speech/silence using Silero VAD.

    https://github.com/snakers4/silero-vad
    """

    def __init__(self, onnx_path: typing.Union[str, Path]):
        onnx_path = str(onnx_path)

        # Thread limits must be set on SessionOptions before the session is
        # created; setting them on the session object afterwards has no effect.
        sess_options = onnxruntime.SessionOptions()
        sess_options.intra_op_num_threads = 1
        sess_options.inter_op_num_threads = 1
        self.session = onnxruntime.InferenceSession(
            onnx_path, sess_options=sess_options
        )

        self._h = np.zeros((2, 1, 64)).astype("float32")
        self._c = np.zeros((2, 1, 64)).astype("float32")

    def __call__(self, audio_array: np.ndarray, sample_rate: int = 16000):
        """Return probability of speech in audio [0-1].

        Audio must be 16 kHz mono (samples are cast to float32).
        """
        if len(audio_array.shape) == 1:
            # Add batch dimension
            audio_array = np.expand_dims(audio_array, 0)

        if len(audio_array.shape) > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {audio_array.shape}"
            )

        if audio_array.shape[0] > 1:
            raise ValueError("Onnx model does not support batching")

        if sample_rate != 16000:
            raise ValueError("Only 16 kHz audio is supported")

        ort_inputs = {
            "input": audio_array.astype(np.float32),
            "h0": self._h,
            "c0": self._c,
        }
        ort_outs = self.session.run(None, ort_inputs)
        out, self._h, self._c = ort_outs

        out = out.squeeze(2)[:, 1]  # make output type match JIT analog

        return out
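A streaming sketch of the detector over 480-sample chunks (the chunk size used by trim_silence above; the model path and threshold are illustrative):

    detector = SileroVoiceActivityDetector("models/silero_vad.onnx")
    for start in range(0, len(audio), 480):  # audio: 1-D float array at 16 kHz
        prob = detector(audio[start : start + 480], sample_rate=16000)
        is_speech = prob >= 0.2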
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
import argparse
import dataclasses
import itertools
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set

from espeak_phonemizer import Phonemizer

from .norm_audio import cache_norm_audio, make_silence_detector
from .phonemize import DEFAULT_PHONEME_ID_MAP, phonemes_to_ids, phonemize

_LOGGER = logging.getLogger("preprocess")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir", required=True, help="Directory with audio dataset"
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write output files for training",
    )
    parser.add_argument("--language", required=True, help="eSpeak-ng voice")
    parser.add_argument(
        "--sample-rate",
        type=int,
        required=True,
        help="Target sample rate for voice (hertz)",
    )
    parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.getLogger("numba").setLevel(logging.WARNING)

    # Convert to paths and create output directories
    args.input_dir = Path(args.input_dir)
    args.output_dir = Path(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    args.cache_dir = (
        Path(args.cache_dir)
        if args.cache_dir
        else args.output_dir / "cache" / str(args.sample_rate)
    )
    args.cache_dir.mkdir(parents=True, exist_ok=True)

    # Count speakers
    _LOGGER.info("Counting number of speakers in the dataset")
    speakers: Set[str] = set()
    for utt in mycroft_dataset(args.input_dir):
        speakers.add(utt.speaker or "")

    is_multispeaker = len(speakers) > 1
    speaker_ids: Dict[str, int] = {}

    if is_multispeaker:
        _LOGGER.info("%s speaker(s) detected", len(speakers))

        # Assign speaker ids in sorted order
        for speaker_id, speaker in enumerate(sorted(speakers)):
            speaker_ids[speaker] = speaker_id
    else:
        _LOGGER.info("Single speaker dataset")

    # Write config
    with open(args.output_dir / "config.json", "w", encoding="utf-8") as config_file:
        json.dump(
            {
                "audio": {
                    "sample_rate": args.sample_rate,
                },
                "espeak": {
                    "voice": args.language,
                },
                "inference": {"noise_scale": 0.667, "length_scale": 1, "noise_w": 0.8},
                "phoneme_map": {},
                "phoneme_id_map": DEFAULT_PHONEME_ID_MAP,
                "num_symbols": len(
                    set(itertools.chain.from_iterable(DEFAULT_PHONEME_ID_MAP.values()))
                ),
                "num_speakers": len(speakers),
                "speaker_id_map": speaker_ids,
            },
            config_file,
            ensure_ascii=False,
            indent=4,
        )
    _LOGGER.info("Wrote dataset config")

    # Used to trim silence
    silence_detector = make_silence_detector()

    with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
        phonemizer = Phonemizer(default_voice=args.language)
        for utt in mycroft_dataset(args.input_dir):
            try:
                utt.audio_path = utt.audio_path.absolute()
                _LOGGER.debug(utt)

                utt.phonemes = phonemize(utt.text, phonemizer)
                utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
                utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
                    utt.audio_path, args.cache_dir, silence_detector, args.sample_rate
                )

                # JSONL
                json.dump(
                    dataclasses.asdict(utt),
                    dataset_file,
                    ensure_ascii=False,
                    cls=PathEncoder,
                )
                print("", file=dataset_file)
            except Exception:
                _LOGGER.exception("Failed to process utterance: %s", utt)


# -----------------------------------------------------------------------------


@dataclass
class Utterance:
    text: str
    audio_path: Path
    speaker: Optional[str] = None
    phonemes: Optional[List[str]] = None
    phoneme_ids: Optional[List[int]] = None
    audio_norm_path: Optional[Path] = None
    audio_spec_path: Optional[Path] = None


class PathEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, Path):
            return str(o)
        return super().default(o)


def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
    for info_path in dataset_dir.glob("*.info"):
        wav_path = info_path.with_suffix(".wav")
        if wav_path.exists():
            text = info_path.read_text(encoding="utf-8").strip()
            yield Utterance(text=text, audio_path=wav_path)


# -----------------------------------------------------------------------------

if __name__ == "__main__":
    main()
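The mycroft_dataset() loader above simply pairs *.info transcripts with same-named *.wav files, so an input directory shaped like the following is sufficient (file names are illustrative):

    dataset/
        utt_0001.info    # plain-text transcript
        utt_0001.wav
        utt_0002.info
        utt_0002.wav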
@@ -0,0 +1,40 @@
[MESSAGES CONTROL]
disable=
    format,
    abstract-class-little-used,
    abstract-method,
    cyclic-import,
    duplicate-code,
    global-statement,
    import-outside-toplevel,
    inconsistent-return-statements,
    locally-disabled,
    not-context-manager,
    redefined-variable-type,
    too-few-public-methods,
    too-many-arguments,
    too-many-branches,
    too-many-instance-attributes,
    too-many-lines,
    too-many-locals,
    too-many-public-methods,
    too-many-return-statements,
    too-many-statements,
    too-many-boolean-expressions,
    unnecessary-pass,
    unused-argument,
    broad-except,
    too-many-nested-blocks,
    invalid-name,
    unused-import,
    no-self-use,
    fixme,
    useless-super-delegation,
    missing-module-docstring,
    missing-class-docstring,
    missing-function-docstring,
    import-error,
    relative-beyond-top-level

[FORMAT]
expected-line-ending-format=LF
@@ -0,0 +1,22 @@
[flake8]
# To work with Black
max-line-length = 88
# E501: line too long
# W503: Line break occurred before a binary operator
# E203: Whitespace before ':'
# D202 No blank lines allowed after function docstring
# W504 line break after binary operator
ignore =
    E501,
    W503,
    E203,
    D202,
    W504

[isort]
multi_line_output = 3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
indent = "    "
@@ -0,0 +1,417 @@
import math
import typing

import torch
from torch import nn
from torch.nn import functional as F

from .commons import subsequent_mask
from .modules import LayerNorm


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int = 1,
        p_dropout: float = 0.0,
        window_size: int = 4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int = 1,
        p_dropout: float = 0.0,
        proximal_bias: bool = False,
        proximal_init: bool = True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = subsequent_mask(x_mask.size(2)).type_as(x)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels: int,
        out_channels: int,
        n_heads: int,
        p_dropout: float = 0.0,
        window_size: typing.Optional[int] = None,
        heads_share: bool = True,
        block_length: typing.Optional[int] = None,
        proximal_bias: bool = False,
        proximal_init: bool = False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).type_as(scores)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        # max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                # convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
                (0, 0, pad_length, pad_length, 0, 0),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()

        # Concat columns of pad to shift from relative to absolute indexing.
        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x = F.pad(x, (0, 1, 0, 0, 0, 0, 0, 0))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_flat = F.pad(x_flat, (0, length - 1, 0, 0, 0, 0))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, (2 * length) - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()

        # pad along column
        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x = F.pad(x, (0, length - 1, 0, 0, 0, 0, 0, 0))
        x_flat = x.view([batch, heads, (length * length) + (length * (length - 1))])
        # add 0's in the beginning that will skew the elements after reshape
        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_flat = F.pad(x_flat, (length, 0, 0, 0, 0, 0))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.

        Args:
            length: an integer scalar.

        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float = 0.0,
        activation: typing.Optional[str] = None,
        causal: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        # x = F.pad(x, convert_pad_shape(padding))
        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        # x = F.pad(x, convert_pad_shape(padding))
        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
        return x
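A shape-level sketch of the relative-attention Encoder above (dimensions match the ModelConfig defaults later in this commit; the input tensors are random placeholders):

    import torch

    enc = Encoder(
        hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6,
        kernel_size=3, p_dropout=0.1,
    )
    x = torch.randn(2, 192, 50)    # [batch, hidden_channels, time]
    x_mask = torch.ones(2, 1, 50)  # 1 = valid position
    y = enc(x, x_mask)             # -> [2, 192, 50]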
@@ -0,0 +1,146 @@
import logging
import math

import torch
from torch.nn import functional as F

_LOGGER = logging.getLogger("vits.commons")


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = max(0, ids_str[i])
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).type_as(mask)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
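For example, intersperse() is how a blank token gets inserted between phoneme IDs (0 here stands for the blank):

    assert intersperse([5, 6, 7], 0) == [0, 5, 0, 6, 0, 7, 0]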
@ -0,0 +1,330 @@
|
||||
"""Configuration classes"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class MelAudioConfig:
|
||||
filter_length: int = 1024
|
||||
hop_length: int = 256
|
||||
win_length: int = 1024
|
||||
mel_channels: int = 80
|
||||
sample_rate: int = 22050
|
||||
sample_bytes: int = 2
|
||||
channels: int = 1
|
||||
mel_fmin: float = 0.0
|
||||
mel_fmax: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelAudioConfig:
|
||||
resblock: str
|
||||
resblock_kernel_sizes: Tuple[int, ...]
|
||||
resblock_dilation_sizes: Tuple[Tuple[int, ...], ...]
|
||||
upsample_rates: Tuple[int, ...]
|
||||
upsample_initial_channel: int
|
||||
upsample_kernel_sizes: Tuple[int, ...]
|
||||
|
||||
@staticmethod
|
||||
def low_quality() -> "ModelAudioConfig":
|
||||
return ModelAudioConfig(
|
||||
resblock="2",
|
||||
resblock_kernel_sizes=(3, 5, 7),
|
||||
resblock_dilation_sizes=(
|
||||
(1, 2),
|
||||
(2, 6),
|
||||
(3, 12),
|
||||
),
|
||||
upsample_rates=(8, 8, 4),
|
||||
upsample_initial_channel=256,
|
||||
upsample_kernel_sizes=(16, 16, 8),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def high_quality() -> "ModelAudioConfig":
|
||||
return ModelAudioConfig(
|
||||
resblock="1",
|
||||
resblock_kernel_sizes=(3, 7, 11),
|
||||
resblock_dilation_sizes=(
|
||||
(1, 3, 5),
|
||||
(1, 3, 5),
|
||||
(1, 3, 5),
|
||||
),
|
||||
upsample_rates=(8, 8, 2, 2),
|
||||
upsample_initial_channel=512,
|
||||
upsample_kernel_sizes=(16, 16, 4, 4),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
num_symbols: int
|
||||
n_speakers: int
|
||||
audio: ModelAudioConfig
|
||||
mel: MelAudioConfig = field(default_factory=MelAudioConfig)
|
||||
|
||||
inter_channels: int = 192
|
||||
hidden_channels: int = 192
|
||||
filter_channels: int = 768
|
||||
n_heads: int = 2
|
||||
n_layers: int = 6
|
||||
kernel_size: int = 3
|
||||
p_dropout: float = 0.1
|
||||
n_layers_q: int = 3
|
||||
use_spectral_norm: bool = False
|
||||
gin_channels: int = 0 # single speaker
|
||||
use_sdp: bool = True # StochasticDurationPredictor
|
||||
segment_size: int = 8192
|
||||
|
||||
@property
|
||||
def is_multispeaker(self) -> bool:
|
||||
return self.n_speakers > 1
|
||||
|
||||
@property
|
||||
def resblock(self) -> str:
|
||||
return self.audio.resblock
|
||||
|
||||
@property
|
||||
def resblock_kernel_sizes(self) -> Tuple[int, ...]:
|
||||
return self.audio.resblock_kernel_sizes
|
||||
|
||||
@property
|
||||
def resblock_dilation_sizes(self) -> Tuple[Tuple[int, ...], ...]:
|
||||
return self.audio.resblock_dilation_sizes
|
||||
|
||||
@property
|
||||
def upsample_rates(self) -> Tuple[int, ...]:
|
||||
return self.audio.upsample_rates
|
||||
|
||||
@property
|
||||
def upsample_initial_channel(self) -> int:
|
||||
return self.audio.upsample_initial_channel
|
||||
|
||||
@property
|
||||
def upsample_kernel_sizes(self) -> Tuple[int, ...]:
|
||||
return self.audio.upsample_kernel_sizes
|
||||
|
||||
def __post_init__(self):
|
||||
if self.is_multispeaker and (self.gin_channels == 0):
|
||||
self.gin_channels = 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
learning_rate: float = 2e-4
|
||||
betas: Tuple[float, float] = field(default=(0.8, 0.99))
|
||||
eps: float = 1e-9
|
||||
# batch_size: int = 32
|
||||
fp16_run: bool = False
|
||||
lr_decay: float = 0.999875
|
||||
init_lr_ratio: float = 1.0
|
||||
warmup_epochs: int = 0
|
||||
c_mel: int = 45
|
||||
c_kl: float = 1.0
|
||||
grad_clip: Optional[float] = None
|
||||
|
||||
|
||||
# @dataclass
|
||||
# class PhonemesConfig(DataClassJsonMixin):
|
||||
# phoneme_separator: str = " "
|
||||
# """Separator between individual phonemes in CSV input"""
|
||||
|
||||
# word_separator: str = "#"
|
||||
# """Separator between word phonemes in CSV input (must not match phoneme_separator)"""
|
||||
|
||||
# phoneme_to_id: typing.Optional[typing.Dict[str, int]] = None
|
||||
# pad: typing.Optional[str] = "_"
|
||||
# bos: typing.Optional[str] = None
|
||||
# eos: typing.Optional[str] = None
|
||||
# blank: typing.Optional[str] = "#"
|
||||
# blank_word: typing.Optional[str] = None
|
||||
# blank_between: typing.Union[str, BlankBetween] = BlankBetween.WORDS
|
||||
# blank_at_start: bool = True
|
||||
# blank_at_end: bool = True
|
||||
# simple_punctuation: bool = True
|
||||
# punctuation_map: typing.Optional[typing.Dict[str, str]] = None
|
||||
# separate: typing.Optional[typing.List[str]] = None
|
||||
# separate_graphemes: bool = False
|
||||
# separate_tones: bool = False
|
||||
# tone_before: bool = False
|
||||
# phoneme_map: typing.Optional[typing.Dict[str, typing.List[str]]] = None
|
||||
#     auto_bos_eos: bool = False
#     minor_break: typing.Optional[str] = IPA.BREAK_MINOR.value
#     major_break: typing.Optional[str] = IPA.BREAK_MAJOR.value
#     break_phonemes_into_graphemes: bool = False
#     break_phonemes_into_codepoints: bool = False
#     drop_stress: bool = False
#     symbols: typing.Optional[typing.List[str]] = None

#     def split_word_phonemes(self, phonemes_str: str) -> typing.List[typing.List[str]]:
#         """Split phonemes string into a list of lists (outer is words, inner is individual phonemes in each word)"""
#         return [
#             word_phonemes_str.split(self.phoneme_separator)
#             if self.phoneme_separator
#             else list(word_phonemes_str)
#             for word_phonemes_str in phonemes_str.split(self.word_separator)
#         ]

#     def join_word_phonemes(self, word_phonemes: typing.List[typing.List[str]]) -> str:
#         """Join a list of lists of phonemes (outer is words) back into a single string"""
#         return self.word_separator.join(
#             self.phoneme_separator.join(wp) for wp in word_phonemes
#         )


# class Phonemizer(str, Enum):
#     SYMBOLS = "symbols"
#     GRUUT = "gruut"
#     ESPEAK = "espeak"
#     EPITRAN = "epitran"


# class Aligner(str, Enum):
#     KALDI_ALIGN = "kaldi_align"


# class TextCasing(str, Enum):
#     LOWER = "lower"
#     UPPER = "upper"


# class MetadataFormat(str, Enum):
#     TEXT = "text"
#     PHONEMES = "phonemes"
#     PHONEME_IDS = "ids"


# @dataclass
# class DatasetConfig:
#     name: str
#     metadata_format: MetadataFormat = MetadataFormat.TEXT
#     multispeaker: bool = False
#     text_language: typing.Optional[str] = None
#     audio_dir: typing.Optional[typing.Union[str, Path]] = None
#     cache_dir: typing.Optional[typing.Union[str, Path]] = None

#     def get_cache_dir(self, output_dir: typing.Union[str, Path]) -> Path:
#         if self.cache_dir is not None:
#             cache_dir = Path(self.cache_dir)
#         else:
#             cache_dir = Path("cache") / self.name

#         if not cache_dir.is_absolute():
#             cache_dir = Path(output_dir) / str(cache_dir)

#         return cache_dir


# @dataclass
# class AlignerConfig:
#     aligner: typing.Optional[Aligner] = None
#     casing: typing.Optional[TextCasing] = None


# @dataclass
# class InferenceConfig:
#     length_scale: float = 1.0
#     noise_scale: float = 0.667
#     noise_w: float = 0.8


# @dataclass
# class TrainingConfig(DataClassJsonMixin):
#     seed: int = 1234
#     epochs: int = 10000
#     learning_rate: float = 2e-4
#     betas: typing.Tuple[float, float] = field(default=(0.8, 0.99))
#     eps: float = 1e-9
#     batch_size: int = 32
#     fp16_run: bool = False
#     lr_decay: float = 0.999875
#     segment_size: int = 8192
#     init_lr_ratio: float = 1.0
#     warmup_epochs: int = 0
#     c_mel: int = 45
#     c_kl: float = 1.0
#     grad_clip: typing.Optional[float] = None

#     min_seq_length: typing.Optional[int] = None
#     max_seq_length: typing.Optional[int] = None

#     min_spec_length: typing.Optional[int] = None
#     max_spec_length: typing.Optional[int] = None

#     min_speaker_utterances: typing.Optional[int] = None

#     last_epoch: int = 1
#     global_step: int = 1
#     best_loss: typing.Optional[float] = None
#     audio: AudioConfig = field(default_factory=AudioConfig)
#     model: ModelConfig = field(default_factory=ModelConfig)
#     phonemes: PhonemesConfig = field(default_factory=PhonemesConfig)
#     text_aligner: AlignerConfig = field(default_factory=AlignerConfig)
#     text_language: typing.Optional[str] = None
#     phonemizer: typing.Optional[Phonemizer] = None
#     datasets: typing.List[DatasetConfig] = field(default_factory=list)
#     inference: InferenceConfig = field(default_factory=InferenceConfig)

#     version: int = 1
#     git_commit: str = ""

#     @property
#     def is_multispeaker(self):
#         return self.model.is_multispeaker or any(d.multispeaker for d in self.datasets)

#     def save(self, config_file: typing.TextIO):
#         """Save config as JSON to a file"""
#         json.dump(self.to_dict(), config_file, indent=4)

#     def get_speaker_id(self, dataset_name: str, speaker_name: str) -> int:
#         if self.speaker_id_map is None:
#             self.speaker_id_map = {}

#         full_speaker_name = f"{dataset_name}_{speaker_name}"
#         speaker_id = self.speaker_id_map.get(full_speaker_name)
#         if speaker_id is None:
#             # Assign the next available id to this speaker
#             speaker_id = len(self.speaker_id_map)
#             self.speaker_id_map[full_speaker_name] = speaker_id

#         return speaker_id

#     @staticmethod
#     def load(config_file: typing.TextIO) -> "TrainingConfig":
#         """Load config from a JSON file"""
#         return TrainingConfig.from_json(config_file.read())

#     @staticmethod
#     def load_and_merge(
#         config: "TrainingConfig",
#         config_files: typing.Iterable[typing.Union[str, Path, typing.TextIO]],
#     ) -> "TrainingConfig":
#         """Load one or more JSON configuration files and overlay them on top of an existing config"""
#         base_dict = config.to_dict()
#         for maybe_config_file in config_files:
#             if isinstance(maybe_config_file, (str, Path)):
#                 # File path
#                 config_file = open(maybe_config_file, "r", encoding="utf-8")
#             else:
#                 # File object
#                 config_file = maybe_config_file

#             with config_file:
#                 # Load new config and overlay on existing config
#                 new_dict = json.load(config_file)
#                 TrainingConfig.recursive_update(base_dict, new_dict)

#         return TrainingConfig.from_dict(base_dict)

#     @staticmethod
#     def recursive_update(
#         base_dict: typing.Dict[typing.Any, typing.Any],
#         new_dict: typing.Mapping[typing.Any, typing.Any],
#     ) -> None:
#         """Recursively overwrite values in base dictionary with values from new dictionary"""
#         for key, value in new_dict.items():
#             # collections.Mapping was removed in Python 3.10; use collections.abc.Mapping
#             if isinstance(value, collections.abc.Mapping) and (
#                 base_dict.get(key) is not None
#             ):
#                 TrainingConfig.recursive_update(base_dict[key], value)
#             else:
#                 base_dict[key] = value
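
# A minimal sketch (not part of the original file) of how the commented-out
# recursive_update above overlays one config onto another: nested mappings are
# merged key-by-key, while scalar values are simply overwritten.
#
#   base = {"audio": {"sample_rate": 22050, "quality": "high"}, "seed": 1234}
#   new = {"audio": {"sample_rate": 16000}, "seed": 0}
#   TrainingConfig.recursive_update(base, new)
#   assert base == {"audio": {"sample_rate": 16000, "quality": "high"}, "seed": 0}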
@ -0,0 +1,208 @@
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Union

import torch
from torch import FloatTensor, LongTensor
from torch.utils.data import Dataset


_LOGGER = logging.getLogger("vits.dataset")


@dataclass
class Utterance:
    phoneme_ids: List[int]
    audio_norm_path: Path
    audio_spec_path: Path
    speaker_id: Optional[int] = None
    text: Optional[str] = None


@dataclass
class UtteranceTensors:
    phoneme_ids: LongTensor
    spectrogram: FloatTensor
    audio_norm: FloatTensor
    speaker_id: Optional[LongTensor] = None
    text: Optional[str] = None

    @property
    def spec_length(self) -> int:
        return self.spectrogram.size(1)


@dataclass
class Batch:
    phoneme_ids: LongTensor
    phoneme_lengths: LongTensor
    spectrograms: FloatTensor
    spectrogram_lengths: LongTensor
    audios: FloatTensor
    audio_lengths: LongTensor
    speaker_ids: Optional[LongTensor] = None
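
# Shapes produced by UtteranceCollate below (B = batch size):
#   phoneme_ids: [B, max_text_len], spectrograms: [B, num_mels, max_spec_len],
#   audios: [B, 1, max_audio_len]; the *_lengths tensors hold unpadded sizes.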


# @dataclass
# class LarynxDatasetSettings:
#     sample_rate: int
#     is_multispeaker: bool
#     espeak_voice: Optional[str] = None
#     phoneme_map: Dict[str, Optional[List[str]]] = field(default_factory=dict)
#     phoneme_id_map: Dict[str, List[int]] = DEFAULT_PHONEME_ID_MAP


class LarynxDataset(Dataset):
    """
    Dataset format (one JSON object per line):

    * phoneme_ids (required)
    * audio_norm_path (required)
    * audio_spec_path (required)
    * text (optional)
    * phonemes (optional)
    * audio_path (optional)
    """

    def __init__(
        self,
        dataset_paths: List[Union[str, Path]],  # settings: LarynxDatasetSettings
    ):
        # self.settings = settings
        self.utterances: List[Utterance] = []

        for dataset_path in dataset_paths:
            dataset_path = Path(dataset_path)
            _LOGGER.debug("Loading dataset: %s", dataset_path)
            self.utterances.extend(LarynxDataset.load_dataset(dataset_path))

    def __len__(self):
        return len(self.utterances)

    def __getitem__(self, idx) -> UtteranceTensors:
        utt = self.utterances[idx]
        return UtteranceTensors(
            phoneme_ids=LongTensor(utt.phoneme_ids),
            audio_norm=torch.load(utt.audio_norm_path),
            spectrogram=torch.load(utt.audio_spec_path),
            speaker_id=LongTensor([utt.speaker_id])
            if utt.speaker_id is not None
            else None,
            text=utt.text,
        )

    @staticmethod
    def load_dataset(dataset_path: Path) -> Iterable[Utterance]:
        with open(dataset_path, "r", encoding="utf-8") as dataset_file:
            for line_idx, line in enumerate(dataset_file):
                line = line.strip()
                if not line:
                    continue

                try:
                    yield LarynxDataset.load_utterance(line)
                except Exception:
                    _LOGGER.exception(
                        "Error on line %s of %s: %s",
                        line_idx + 1,
                        dataset_path,
                        line,
                    )

    @staticmethod
    def load_utterance(line: str) -> Utterance:
        utt_dict = json.loads(line)
        return Utterance(
            phoneme_ids=utt_dict["phoneme_ids"],
            audio_norm_path=Path(utt_dict["audio_norm_path"]),
            audio_spec_path=Path(utt_dict["audio_spec_path"]),
            speaker_id=utt_dict.get("speaker_id"),
            text=utt_dict.get("text"),
        )
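
    # Example (hypothetical values) of a single dataset line that
    # load_utterance above would parse:
    #
    #   {"phoneme_ids": [1, 20, 5, 32], "audio_norm_path": "cache/0.pt",
    #    "audio_spec_path": "cache/0.spec.pt", "speaker_id": 0, "text": "hello"}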


class UtteranceCollate:
    def __init__(self, is_multispeaker: bool, segment_size: int):
        self.is_multispeaker = is_multispeaker
        self.segment_size = segment_size

    def __call__(self, utterances: Sequence[UtteranceTensors]) -> Batch:
        num_utterances = len(utterances)
        assert num_utterances > 0, "No utterances"

        max_phonemes_length = 0
        max_spec_length = 0
        max_audio_length = 0

        num_mels = 0

        # Determine lengths
        for utt_idx, utt in enumerate(utterances):
            assert utt.spectrogram is not None
            assert utt.audio_norm is not None

            phoneme_length = utt.phoneme_ids.size(0)
            spec_length = utt.spectrogram.size(1)
            audio_length = utt.audio_norm.size(1)

            max_phonemes_length = max(max_phonemes_length, phoneme_length)
            max_spec_length = max(max_spec_length, spec_length)
            max_audio_length = max(max_audio_length, audio_length)

            num_mels = utt.spectrogram.size(0)
            if self.is_multispeaker:
                assert utt.speaker_id is not None, "Missing speaker id"

        # Audio cannot be smaller than the training segment size (default: 8192 samples)
        max_audio_length = max(max_audio_length, self.segment_size)

        # Create padded tensors
        phonemes_padded = LongTensor(num_utterances, max_phonemes_length)
        spec_padded = FloatTensor(num_utterances, num_mels, max_spec_length)
        audio_padded = FloatTensor(num_utterances, 1, max_audio_length)

        phonemes_padded.zero_()
        spec_padded.zero_()
        audio_padded.zero_()

        phoneme_lengths = LongTensor(num_utterances)
        spec_lengths = LongTensor(num_utterances)
        audio_lengths = LongTensor(num_utterances)

        speaker_ids: Optional[LongTensor] = None
        if self.is_multispeaker:
            speaker_ids = LongTensor(num_utterances)

        # Sort by decreasing spectrogram length
        sorted_utterances = sorted(
            utterances, key=lambda u: u.spectrogram.size(1), reverse=True
        )
        for utt_idx, utt in enumerate(sorted_utterances):
            phoneme_length = utt.phoneme_ids.size(0)
            spec_length = utt.spectrogram.size(1)
            audio_length = utt.audio_norm.size(1)

            phonemes_padded[utt_idx, :phoneme_length] = utt.phoneme_ids
            phoneme_lengths[utt_idx] = phoneme_length

            spec_padded[utt_idx, :, :spec_length] = utt.spectrogram
            spec_lengths[utt_idx] = spec_length

            audio_padded[utt_idx, :, :audio_length] = utt.audio_norm
            audio_lengths[utt_idx] = audio_length

            if utt.speaker_id is not None:
                assert speaker_ids is not None
                speaker_ids[utt_idx] = utt.speaker_id

        return Batch(
            phoneme_ids=phonemes_padded,
            phoneme_lengths=phoneme_lengths,
            spectrograms=spec_padded,
            spectrogram_lengths=spec_lengths,
            audios=audio_padded,
            audio_lengths=audio_lengths,
            speaker_ids=speaker_ids,
        )
@ -0,0 +1,330 @@
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from torch import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from .commons import slice_segments
from .dataset import Batch, LarynxDataset, UtteranceCollate
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import MultiPeriodDiscriminator, SynthesizerTrn

_LOGGER = logging.getLogger("vits.lightning")


class VitsModel(pl.LightningModule):
    def __init__(
        self,
        num_symbols: int,
        num_speakers: int,
        # audio
        resblock="2",
        resblock_kernel_sizes=(3, 5, 7),
        resblock_dilation_sizes=(
            (1, 2),
            (2, 6),
            (3, 12),
        ),
        upsample_rates=(8, 8, 4),
        upsample_initial_channel=256,
        upsample_kernel_sizes=(16, 16, 8),
        # mel
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_channels: int = 80,
        sample_rate: int = 22050,
        sample_bytes: int = 2,
        channels: int = 1,
        mel_fmin: float = 0.0,
        mel_fmax: Optional[float] = None,
        # model
        inter_channels: int = 192,
        hidden_channels: int = 192,
        filter_channels: int = 768,
        n_heads: int = 2,
        n_layers: int = 6,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        n_layers_q: int = 3,
        use_spectral_norm: bool = False,
        gin_channels: int = 0,
        use_sdp: bool = True,
        segment_size: int = 8192,
        # training
        dataset: Optional[List[Union[str, Path]]] = None,
        learning_rate: float = 2e-4,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps: float = 1e-9,
        batch_size: int = 1,
        lr_decay: float = 0.999875,
        init_lr_ratio: float = 1.0,
        warmup_epochs: int = 0,
        c_mel: int = 45,
        c_kl: float = 1.0,
        grad_clip: Optional[float] = None,
        num_workers: int = 1,
        seed: int = 1234,
        num_test_examples: int = 5,
        validation_split: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        if (self.hparams.num_speakers > 1) and (self.hparams.gin_channels <= 0):
            # Default gin_channels for multi-speaker model
            self.hparams.gin_channels = 512

        # Set up models
        self.model_g = SynthesizerTrn(
            n_vocab=self.hparams.num_symbols,
            spec_channels=self.hparams.filter_length // 2 + 1,
            segment_size=self.hparams.segment_size // self.hparams.hop_length,
            inter_channels=self.hparams.inter_channels,
            hidden_channels=self.hparams.hidden_channels,
            filter_channels=self.hparams.filter_channels,
            n_heads=self.hparams.n_heads,
            n_layers=self.hparams.n_layers,
            kernel_size=self.hparams.kernel_size,
            p_dropout=self.hparams.p_dropout,
            resblock=self.hparams.resblock,
            resblock_kernel_sizes=self.hparams.resblock_kernel_sizes,
            resblock_dilation_sizes=self.hparams.resblock_dilation_sizes,
            upsample_rates=self.hparams.upsample_rates,
            upsample_initial_channel=self.hparams.upsample_initial_channel,
            upsample_kernel_sizes=self.hparams.upsample_kernel_sizes,
            n_speakers=self.hparams.num_speakers,
            gin_channels=self.hparams.gin_channels,
            use_sdp=self.hparams.use_sdp,
        )
        self.model_d = MultiPeriodDiscriminator(
            use_spectral_norm=self.hparams.use_spectral_norm
        )

        # Dataset splits
        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._test_dataset: Optional[Dataset] = None
        self._load_datasets(validation_split, num_test_examples)

        # State kept between training optimizers
        self._y = None
        self._y_hat = None

    def _load_datasets(self, validation_split: float, num_test_examples: int):
        full_dataset = LarynxDataset(self.hparams.dataset)
        valid_set_size = int(len(full_dataset) * validation_split)
        train_set_size = len(full_dataset) - valid_set_size - num_test_examples

        self._train_dataset, self._test_dataset, self._val_dataset = random_split(
            full_dataset, [train_set_size, num_test_examples, valid_set_size]
        )

    def forward(self, text, text_lengths, scales, sid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio, *_ = self.model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )

        return audio
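
    # A minimal usage sketch (hypothetical tensors, not from the original code):
    # phoneme ids come from the pre-processed dataset, and scales packs
    # [noise_scale, length_scale, noise_scale_w] as read above.
    #
    #   model = VitsModel.load_from_checkpoint("epoch=0-step=0.ckpt")
    #   text = torch.LongTensor([[1, 20, 5, 32]])
    #   audio = model(text, torch.LongTensor([text.size(1)]), scales=[0.667, 1.0, 0.8])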

    def train_dataloader(self):
        return DataLoader(
            self._train_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self._val_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self._test_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def training_step(self, batch: Batch, batch_idx: int, optimizer_idx: int):
        if optimizer_idx == 0:
            return self.training_step_g(batch)

        if optimizer_idx == 1:
            return self.training_step_d(batch)
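
    # Note: with two optimizers configured below, PyTorch Lightning (1.x API)
    # calls training_step once per optimizer for each batch: first with
    # optimizer_idx=0 (generator), then optimizer_idx=1 (discriminator).
    # That is why _y and _y_hat are cached on self between the two calls.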

    def training_step_g(self, batch: Batch):
        x, x_lengths, y, _, spec, spec_lengths, speaker_ids = (
            batch.phoneme_ids,
            batch.phoneme_lengths,
            batch.audios,
            batch.audio_lengths,
            batch.spectrograms,
            batch.spectrogram_lengths,
            batch.speaker_ids,
        )
        (
            y_hat,
            l_length,
            _attn,
            ids_slice,
            _x_mask,
            z_mask,
            (_z, z_p, m_p, logs_p, _m_q, logs_q),
        ) = self.model_g(x, x_lengths, spec, spec_lengths, speaker_ids)
        self._y_hat = y_hat

        mel = spec_to_mel_torch(
            spec,
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y_mel = slice_segments(
            mel,
            ids_slice,
            self.hparams.segment_size // self.hparams.hop_length,
        )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1),
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y = slice_segments(
            y,
            ids_slice * self.hparams.hop_length,
            self.hparams.segment_size,
        )  # slice audio to match the generated segment

        # Save for training_step_d
        self._y = y

        _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)

        with autocast(self.device.type, enabled=False):
            # Generator loss
            loss_dur = torch.sum(l_length.float())
            loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.c_mel
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.c_kl

            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _losses_gen = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

            self.log("loss_gen_all", loss_gen_all)

            return loss_gen_all

    def training_step_d(self, batch: Batch):
        # From training_step_g
        y = self._y
        y_hat = self._y_hat
        y_d_hat_r, y_d_hat_g, _, _ = self.model_d(y, y_hat.detach())

        with autocast(self.device.type, enabled=False):
            # Discriminator loss
            loss_disc, _losses_disc_r, _losses_disc_g = discriminator_loss(
                y_d_hat_r, y_d_hat_g
            )
            loss_disc_all = loss_disc

            self.log("loss_disc_all", loss_disc_all)

            return loss_disc_all

    def validation_step(self, batch: Batch, batch_idx: int):
        val_loss = self.training_step_g(batch)
        self.log("val_loss", val_loss)

        # Generate audio examples
        for utt_idx, test_utt in enumerate(self._test_dataset):
            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
            scales = [0.667, 1.0, 0.8]
            test_audio = self(text, text_lengths, scales).detach()

            # Scale to make louder in [-1, 1]
            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

            tag = test_utt.text or str(utt_idx)
            self.logger.experiment.add_audio(
                tag, test_audio, sample_rate=self.hparams.sample_rate
            )

        return val_loss

    def configure_optimizers(self):
        optimizers = [
            torch.optim.AdamW(
                self.model_g.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
            torch.optim.AdamW(
                self.model_d.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
        ]
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[0], gamma=self.hparams.lr_decay
            ),
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[1], gamma=self.hparams.lr_decay
            ),
        ]

        return optimizers, schedulers

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("VitsModel")
        parser.add_argument("--batch-size", type=int, required=True)
        parser.add_argument("--validation-split", type=float, default=0.1)
        parser.add_argument("--num-test-examples", type=int, default=5)
        #
        parser.add_argument("--hidden-channels", type=int, default=192)
        parser.add_argument("--inter-channels", type=int, default=192)
        parser.add_argument("--filter-channels", type=int, default=768)
        parser.add_argument("--n-layers", type=int, default=6)
        parser.add_argument("--n-heads", type=int, default=2)
        #
        return parent_parser
@ -0,0 +1,58 @@
import torch


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l_dg = torch.mean((1 - dg) ** 2)
        gen_losses.append(l_dg)
        loss += l_dg

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    l_kl = kl / torch.sum(z_mask)
    return l_kl
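

# kl_loss above is the standard VITS KL term between the posterior
# q = N(m_q, exp(logs_q)) and the prior p = N(m_p, exp(logs_p)), evaluated at
# the posterior sample z_p. Per element it computes
#
#     logs_p - logs_q - 1/2 + (z_p - m_p)^2 / (2 * exp(2 * logs_p))
#
# i.e. log q(z_p) - log p(z_p) with the q-side quadratic term replaced by its
# expectation 1/2, so the masked mean is a one-sample estimate of KL(q || p).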
@ -0,0 +1,137 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


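# Caches keyed by "<fmax>_<dtype>_<device>" and "<win_size>_<dtype>_<device>" so
# the mel filterbank and Hann window are built once per configuration and
# reused across calls (see spectrogram_torch and mel_spectrogram_torch below).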
mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).type_as(y)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).type_as(spec)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        # Keyword arguments for compatibility with librosa >= 0.10 (matches the
        # call in spec_to_mel_torch above; the old positional form fails there)
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).type_as(y)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).type_as(y)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
@ -0,0 +1,727 @@
import math
import typing

import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from . import attentions, commons, modules, monotonic_align
from .commons import get_padding, init_weights


class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float,
        n_flows: int = 4,
        gin_channels: int = 0,
    ):
        super().__init__()
        filter_channels = in_channels  # NOTE: this override should be removed in a future version
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).type_as(x) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).type_as(x) * noise_scale

            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class DurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float,
        gin_channels: int = 0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab: int,
        out_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(
            commons.sequence_mask(x_lengths, x.size(2)), 1
        ).type_as(x)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        n_flows: int = 4,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(
            commons.sequence_mask(x_lengths, x.size(2)), 1
        ).type_as(x)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel: int,
        resblock: typing.Optional[str],
        resblock_kernel_sizes: typing.Tuple[int, ...],
        resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
        upsample_rates: typing.Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: typing.Tuple[int, ...],
        gin_channels: int = 0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock_module = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock_module(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
    def __init__(
        self,
        period: int,
        kernel_size: int = 5,
        stride: int = 3,
        use_spectral_norm: bool = False,
    ):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab: int,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: typing.Tuple[int, ...],
        resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
        upsample_rates: typing.Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: typing.Tuple[int, ...],
        n_speakers: int = 1,
        gin_channels: int = 0,
        use_sdp: bool = True,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        if use_sdp:
            self.dp = StochasticDurationPredictor(
                hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
            )
        else:
            self.dp = DurationPredictor(
                hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
            )

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, x, x_lengths, y, y_lengths, sid=None):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 1:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
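
            # The four terms above decompose log N(z_p; m_p, exp(logs_p)), summed
            # over channels, into matmul-friendly pieces:
            #   -1/2 log(2*pi) - logs_p           (neg_cent1, constant per frame)
            #   -1/2 z_p^2 * exp(-2 logs_p)       (neg_cent2)
            #   +z_p * m_p * exp(-2 logs_p)       (neg_cent3)
            #   -1/2 m_p^2 * exp(-2 logs_p)       (neg_cent4)
            # giving the [b, t_t, t_s] score matrix that maximum_path aligns over.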

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
                x_mask
            )  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )

    def infer(
        self,
        x,
        x_lengths,
        sid=None,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
    ):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 1:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(
            commons.sequence_mask(y_lengths, y_lengths.max()), 1
        ).type_as(x_mask)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)

        return o, attn, y_mask, (z, z_p, m_p, logs_p)
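
    # A minimal inference sketch (hypothetical inputs, single speaker; the
    # constructor arguments are elided here):
    #
    #   net_g = SynthesizerTrn(...)
    #   x = torch.LongTensor([[1, 20, 5, 32]])
    #   o, attn, y_mask, _ = net_g.infer(x, torch.LongTensor([x.size(1)]))
    #   # o: [1, 1, T_audio] waveform; with use_sdp=True the stochastic duration
    #   # predictor makes repeated calls yield slightly different lengths.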

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        assert self.n_speakers > 1, "n_speakers must be greater than 1"
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
@ -0,0 +1,527 @@
import math
import typing

import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm

from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights
from .transforms import piecewise_rational_quadratic_transform

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        kernel_size: int,
        n_layers: int,
        p_dropout: float,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(
        self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0
    ):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class WN(torch.nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
        p_dropout: float = 0,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

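            # fused_add_tanh_sigmoid_multiply computes the WaveNet-style gate:
            # split (x_in + g_l) into two halves of n_channels each, then
            # tanh(first half) * sigmoid(second half).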
            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)


class ResBlock1(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: typing.Tuple[int, ...] = (1, 3, 5),
    ):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: typing.Tuple[int, ...] = (1, 3),
    ):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Log(nn.Module):
    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor, reverse: bool = False, **kwargs
    ):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x


class Flip(nn.Module):
    def forward(self, x: torch.Tensor, *args, reverse: bool = False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).type_as(x)
            return x, logdet
        else:
            return x


class ElementwiseAffine(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
|
||||
if not reverse:
|
||||
y = self.m + torch.exp(self.logs) * x
|
||||
y = y * x_mask
|
||||
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCouplingLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
hidden_channels: int,
|
||||
kernel_size: int,
|
||||
dilation_rate: int,
|
||||
n_layers: int,
|
||||
p_dropout: float = 0,
|
||||
gin_channels: int = 0,
|
||||
mean_only: bool = False,
|
||||
):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.half_channels = channels // 2
|
||||
self.mean_only = mean_only
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.enc = WN(
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
p_dropout=p_dropout,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
self.post.weight.data.zero_()
|
||||
self.post.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0) * x_mask
|
||||
h = self.enc(h, x_mask, g=g)
|
||||
stats = self.post(h) * x_mask
|
||||
if not self.mean_only:
|
||||
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
||||
else:
|
||||
m = stats
|
||||
logs = torch.zeros_like(m)
|
||||
|
||||
if not reverse:
|
||||
x1 = m + x1 * torch.exp(logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
logdet = torch.sum(logs, [1, 2])
|
||||
return x, logdet
|
||||
else:
|
||||
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
return x
|
||||
|
||||
|
||||
class ConvFlow(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
filter_channels: int,
|
||||
kernel_size: int,
|
||||
n_layers: int,
|
||||
num_bins: int = 10,
|
||||
tail_bound: float = 5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.num_bins = num_bins
|
||||
self.tail_bound = tail_bound
|
||||
self.half_channels = in_channels // 2
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
||||
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
||||
self.proj = nn.Conv1d(
|
||||
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
||||
)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0)
|
||||
h = self.convs(h, x_mask, g=g)
|
||||
h = self.proj(h) * x_mask
|
||||
|
||||
b, c, t = x0.shape
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||
|
||||
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
||||
self.filter_channels
|
||||
)
|
||||
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
||||
|
||||
x1, logabsdet = piecewise_rational_quadratic_transform(
|
||||
x1,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=reverse,
|
||||
tails="linear",
|
||||
tail_bound=self.tail_bound,
|
||||
)
|
||||
|
||||
x = torch.cat([x0, x1], 1) * x_mask
|
||||
|
||||
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
||||
if not reverse:
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
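

# Shape sketch for the coupling layers above (illustrative assumptions, not
# part of the original file): inputs are [batch, channels, time] tensors with
# a broadcastable [batch, 1, time] mask. Because `post` is zero-initialized,
# the flow starts as the identity, and `reverse=True` inverts it exactly:
#
#   layer = ResidualCouplingLayer(192, 256, 5, 1, 4)
#   x = torch.randn(3, 192, 50)
#   x_mask = torch.ones(3, 1, 50)
#   y, logdet = layer(x, x_mask)              # forward pass with log-determinant
#   x_back = layer(y, x_mask, reverse=True)   # recovers x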
@ -0,0 +1,2 @@
all:
	python3 setup.py build_ext --inplace
@ -0,0 +1,20 @@
import numpy as np
import torch

from .monotonic_align.core import maximum_path_c


def maximum_path(neg_cent, mask):
    """Cython optimized version.
    neg_cent: [b, t_t, t_s]
    mask: [b, t_t, t_s]
    """
    device = neg_cent.device
    dtype = neg_cent.dtype
    neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
    path = np.zeros(neg_cent.shape, dtype=np.int32)

    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
    maximum_path_c(path, neg_cent, t_t_max, t_s_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
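
# A minimal usage sketch (assumed shapes, not part of the original file):
# `neg_cent` holds alignment scores over [batch, t_t, t_s] and `mask` marks
# the valid region; a valid monotonic path needs t_t >= t_s per utterance.
#
#   b, t_t, t_s = 2, 8, 5
#   neg_cent = torch.randn(b, t_t, t_s)
#   mask = torch.ones(b, t_t, t_s)
#   path = maximum_path(neg_cent, mask)  # 0/1 alignment, shape [b, t_t, t_s]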
File diff suppressed because it is too large
@ -0,0 +1,42 @@
cimport cython
from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[y-1, x]
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y-1, x-1]
            value[y, x] += max(v_prev, v_cur)

    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
            index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
    cdef int b = paths.shape[0]
    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
@ -0,0 +1,13 @@
from distutils.core import setup
from pathlib import Path

import numpy
from Cython.Build import cythonize

_DIR = Path(__file__).parent

setup(
    name="monotonic_align",
    ext_modules=cythonize(str(_DIR / "core.pyx")),
    include_dirs=[numpy.get_include()],
)
@ -0,0 +1,212 @@
import numpy as np
import torch
from torch.nn import functional as F

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):

    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    # bin_locations[..., -1] += eps
    bin_locations[..., bin_locations.size(-1) - 1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        # unnormalized_derivatives[..., -1] = constant
        unnormalized_derivatives[..., unnormalized_derivatives.size(-1) - 1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    # if torch.min(inputs) < left or torch.max(inputs) > right:
    #     raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    # if min_bin_width * num_bins > 1.0:
    #     raise ValueError("Minimal bin width too large for the number of bins")
    # if min_bin_height * num_bins > 1.0:
    #     raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    # cumwidths[..., -1] = right
    cumwidths[..., cumwidths.size(-1) - 1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    # cumheights[..., -1] = top
    cumheights[..., cumheights.size(-1) - 1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all(), discriminant

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet

    theta = (inputs - input_cumwidths) / input_bin_widths
    theta_one_minus_theta = theta * (1 - theta)

    numerator = input_heights * (
        input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
    )
    denominator = input_delta + (
        (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
        * theta_one_minus_theta
    )
    outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    return outputs, logabsdet
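
# Round-trip sketch for the spline (illustrative assumptions, not part of the
# original file). With "linear" tails there are num_bins - 1 interior
# derivatives before padding; applying the transform with inverse=True to its
# own output recovers the inputs up to numerical error:
#
#   num_bins = 10
#   x = torch.rand(8) * 2 - 1
#   w = torch.randn(8, num_bins)
#   h = torch.randn(8, num_bins)
#   d = torch.randn(8, num_bins - 1)
#   y, _ = piecewise_rational_quadratic_transform(
#       x, w, h, d, tails="linear", tail_bound=5.0
#   )
#   x_back, _ = piecewise_rational_quadratic_transform(
#       y, w, h, d, inverse=True, tails="linear", tail_bound=5.0
#   )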
@ -0,0 +1,16 @@
import numpy as np
import torch


def to_gpu(x: torch.Tensor) -> torch.Tensor:
    return x.contiguous().cuda(non_blocking=True)


def audio_float_to_int16(
    audio: np.ndarray, max_wav_value: float = 32767.0
) -> np.ndarray:
    """Normalize audio and convert to int16 range"""
    audio_norm = audio * (max_wav_value / max(0.01, np.max(np.abs(audio))))
    audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
    audio_norm = audio_norm.astype("int16")
    return audio_norm
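
# Expected scaling behavior (an illustrative sketch, not part of the original
# file): the signal is peak-normalized into the int16 range, so a quiet float
# signal is amplified and a full-scale one maps to +/-32767:
#
#   audio = 0.25 * np.sin(2 * np.pi * 440 * np.linspace(0.0, 1.0, 22050))
#   pcm = audio_float_to_int16(audio)
#   assert pcm.dtype == np.int16
#   assert int(np.abs(pcm).max()) <= 32767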
@ -0,0 +1,172 @@
import json
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from typing import List, Optional

import librosa
import torch
from dataclasses_json import DataClassJsonMixin
from torch import Tensor


@dataclass
class DatasetUtterance(DataClassJsonMixin):
    id: str
    text: Optional[str] = None
    phonemes: Optional[List[str]] = None
    phoneme_ids: Optional[List[int]] = None
    audio_path: Optional[Path] = None
    audio_norm_path: Optional[Path] = None
    mel_spec_path: Optional[Path] = None
    speaker: Optional[str] = None
    speaker_id: Optional[int] = None

    def __post_init__(self):
        self._original_json: Optional[str] = None

    @property
    def original_json(self) -> str:
        if self._original_json is None:
            self._original_json = self.to_json(ensure_ascii=False)

        return self._original_json


@dataclass
class TrainingUtterance:
    id: str
    phoneme_ids: Tensor
    audio_norm: Tensor
    mel_spec: Tensor
    speaker_id: Optional[Tensor] = None


# -----------------------------------------------------------------------------


@dataclass
class UtteranceLoadingContext:
    cache_dir: Path
    is_multispeaker: bool
    phonemize: Optional[Popen] = None
    phonemes2ids: Optional[Popen] = None
    speaker2id: Optional[Popen] = None
    audio2norm: Optional[Popen] = None
    audio2spec: Optional[Popen] = None


def load_utterance(
    utterance_json: str, context: UtteranceLoadingContext
) -> TrainingUtterance:
    data_utterance = DatasetUtterance.from_json(utterance_json)

    # pylint: disable=protected-access
    data_utterance._original_json = utterance_json

    # Requirements:
    # 1. phoneme ids
    # 2. audio norm
    # 3. mel spec
    # 4. speaker id (if multispeaker)

    # 1. phoneme ids
    if data_utterance.phoneme_ids is None:
        _load_phoneme_ids(data_utterance, context)

    # 2. audio norm
    if (data_utterance.audio_norm_path is None) or (
        not data_utterance.audio_norm_path.exists()
    ):
        _load_audio_norm(data_utterance, context)

    # 3. mel spec
    if (data_utterance.mel_spec_path is None) or (
        not data_utterance.mel_spec_path.exists()
    ):
        _load_mel_spec(data_utterance, context)

    # 4. speaker id
    if context.is_multispeaker:
        if data_utterance.speaker_id is None:
            _load_speaker_id(data_utterance, context)

    # Convert to training utterance
    assert data_utterance.phoneme_ids is not None
    assert data_utterance.audio_norm_path is not None
    assert data_utterance.mel_spec_path is not None

    if context.is_multispeaker:
        assert data_utterance.speaker_id is not None

    train_utterance = TrainingUtterance(
        id=data_utterance.id,
        phoneme_ids=torch.LongTensor(data_utterance.phoneme_ids),
        audio_norm=torch.load(data_utterance.audio_norm_path),
        mel_spec=torch.load(data_utterance.mel_spec_path),
        # Wrap the scalar speaker id in a list: torch.LongTensor(int) would
        # allocate an uninitialized tensor of that size instead of storing it
        speaker_id=None
        if data_utterance.speaker_id is None
        else torch.LongTensor([data_utterance.speaker_id]),
    )

    return train_utterance


def _load_phoneme_ids(
    data_utterance: DatasetUtterance, context: UtteranceLoadingContext
):
    if data_utterance.phonemes is None:
        # Need phonemes first
        _load_phonemes(data_utterance, context)

    assert (
        data_utterance.phonemes is not None
    ), f"phonemes is required for phoneme ids: {data_utterance}"

    assert (
        context.phonemes2ids is not None
    ), f"phonemes2ids program is required for phoneme ids: {data_utterance}"

    assert context.phonemes2ids.stdin is not None
    assert context.phonemes2ids.stdout is not None

    # JSON in, JSON out
    print(data_utterance.original_json, file=context.phonemes2ids.stdin, flush=True)
    result_json = context.phonemes2ids.stdout.readline()
    result_dict = json.loads(result_json)

    # Update utterance
    data_utterance.phoneme_ids = result_dict["phoneme_ids"]
    data_utterance._original_json = result_json


def _load_phonemes(data_utterance: DatasetUtterance, context: UtteranceLoadingContext):
    assert (
        data_utterance.text is not None
    ), f"text is required for phonemes: {data_utterance}"

    assert (
        context.phonemize is not None
    ), f"phonemize program is required for phonemes: {data_utterance}"

    assert context.phonemize.stdin is not None
    assert context.phonemize.stdout is not None

    # JSON in, JSON out
    print(data_utterance.original_json, file=context.phonemize.stdin, flush=True)
    result_json = context.phonemize.stdout.readline()
    result_dict = json.loads(result_json)

    # Update utterance
    data_utterance.phonemes = result_dict["phonemes"]
    data_utterance._original_json = result_json


def _load_audio_norm(
    data_utterance: DatasetUtterance, context: UtteranceLoadingContext
):
    pass


def _load_mel_spec(data_utterance: DatasetUtterance, context: UtteranceLoadingContext):
    pass
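
# The two loaders above are stubs in this initial check-in. Based on the
# JSON-over-stdin protocol used by _load_phonemes/_load_phoneme_ids, a sketch
# of _load_audio_norm would presumably look like this (assumption only; the
# "audio_norm_path" result key is hypothetical):
#
#   assert context.audio2norm is not None
#   assert context.audio2norm.stdin is not None
#   assert context.audio2norm.stdout is not None
#   print(data_utterance.original_json, file=context.audio2norm.stdin, flush=True)
#   result_json = context.audio2norm.stdout.readline()
#   data_utterance.audio_norm_path = Path(json.loads(result_json)["audio_norm_path"])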
@ -0,0 +1,860 @@
"""
Module to read / write wav files using NumPy arrays

Functions
---------
`read`: Return the sample rate (in samples/sec) and data from a WAV file.

`write`: Write a NumPy array as a WAV file.

"""
import io
import struct
import sys
import warnings
from enum import IntEnum

import numpy

__all__ = ["WavFileWarning", "read", "write"]


class WavFileWarning(UserWarning):
    pass


class WAVE_FORMAT(IntEnum):
    """
    WAVE form wFormatTag IDs

    Complete list is in mmreg.h in Windows 10 SDK. ALAC and OPUS are the
    newest additions, in v10.0.14393 2016-07
    """

    UNKNOWN = 0x0000
    PCM = 0x0001
    ADPCM = 0x0002
    IEEE_FLOAT = 0x0003
    VSELP = 0x0004
    IBM_CVSD = 0x0005
    ALAW = 0x0006
    MULAW = 0x0007
    DTS = 0x0008
    DRM = 0x0009
    WMAVOICE9 = 0x000A
    WMAVOICE10 = 0x000B
    OKI_ADPCM = 0x0010
    DVI_ADPCM = 0x0011
    IMA_ADPCM = 0x0011  # Duplicate
    MEDIASPACE_ADPCM = 0x0012
    SIERRA_ADPCM = 0x0013
    G723_ADPCM = 0x0014
    DIGISTD = 0x0015
    DIGIFIX = 0x0016
    DIALOGIC_OKI_ADPCM = 0x0017
    MEDIAVISION_ADPCM = 0x0018
    CU_CODEC = 0x0019
    HP_DYN_VOICE = 0x001A
    YAMAHA_ADPCM = 0x0020
    SONARC = 0x0021
    DSPGROUP_TRUESPEECH = 0x0022
    ECHOSC1 = 0x0023
    AUDIOFILE_AF36 = 0x0024
    APTX = 0x0025
    AUDIOFILE_AF10 = 0x0026
    PROSODY_1612 = 0x0027
    LRC = 0x0028
    DOLBY_AC2 = 0x0030
    GSM610 = 0x0031
    MSNAUDIO = 0x0032
    ANTEX_ADPCME = 0x0033
    CONTROL_RES_VQLPC = 0x0034
    DIGIREAL = 0x0035
    DIGIADPCM = 0x0036
    CONTROL_RES_CR10 = 0x0037
    NMS_VBXADPCM = 0x0038
    CS_IMAADPCM = 0x0039
    ECHOSC3 = 0x003A
    ROCKWELL_ADPCM = 0x003B
    ROCKWELL_DIGITALK = 0x003C
    XEBEC = 0x003D
    G721_ADPCM = 0x0040
    G728_CELP = 0x0041
    MSG723 = 0x0042
    INTEL_G723_1 = 0x0043
    INTEL_G729 = 0x0044
    SHARP_G726 = 0x0045
    MPEG = 0x0050
    RT24 = 0x0052
    PAC = 0x0053
    MPEGLAYER3 = 0x0055
    LUCENT_G723 = 0x0059
    CIRRUS = 0x0060
    ESPCM = 0x0061
    VOXWARE = 0x0062
    CANOPUS_ATRAC = 0x0063
    G726_ADPCM = 0x0064
    G722_ADPCM = 0x0065
    DSAT = 0x0066
    DSAT_DISPLAY = 0x0067
    VOXWARE_BYTE_ALIGNED = 0x0069
    VOXWARE_AC8 = 0x0070
    VOXWARE_AC10 = 0x0071
    VOXWARE_AC16 = 0x0072
    VOXWARE_AC20 = 0x0073
    VOXWARE_RT24 = 0x0074
    VOXWARE_RT29 = 0x0075
    VOXWARE_RT29HW = 0x0076
    VOXWARE_VR12 = 0x0077
    VOXWARE_VR18 = 0x0078
    VOXWARE_TQ40 = 0x0079
    VOXWARE_SC3 = 0x007A
    VOXWARE_SC3_1 = 0x007B
    SOFTSOUND = 0x0080
    VOXWARE_TQ60 = 0x0081
    MSRT24 = 0x0082
    G729A = 0x0083
    MVI_MVI2 = 0x0084
    DF_G726 = 0x0085
    DF_GSM610 = 0x0086
    ISIAUDIO = 0x0088
    ONLIVE = 0x0089
    MULTITUDE_FT_SX20 = 0x008A
    INFOCOM_ITS_G721_ADPCM = 0x008B
    CONVEDIA_G729 = 0x008C
    CONGRUENCY = 0x008D
    SBC24 = 0x0091
    DOLBY_AC3_SPDIF = 0x0092
    MEDIASONIC_G723 = 0x0093
    PROSODY_8KBPS = 0x0094
    ZYXEL_ADPCM = 0x0097
    PHILIPS_LPCBB = 0x0098
    PACKED = 0x0099
    MALDEN_PHONYTALK = 0x00A0
    RACAL_RECORDER_GSM = 0x00A1
    RACAL_RECORDER_G720_A = 0x00A2
    RACAL_RECORDER_G723_1 = 0x00A3
    RACAL_RECORDER_TETRA_ACELP = 0x00A4
    NEC_AAC = 0x00B0
    RAW_AAC1 = 0x00FF
    RHETOREX_ADPCM = 0x0100
    IRAT = 0x0101
    VIVO_G723 = 0x0111
    VIVO_SIREN = 0x0112
    PHILIPS_CELP = 0x0120
    PHILIPS_GRUNDIG = 0x0121
    DIGITAL_G723 = 0x0123
    SANYO_LD_ADPCM = 0x0125
    SIPROLAB_ACEPLNET = 0x0130
    SIPROLAB_ACELP4800 = 0x0131
    SIPROLAB_ACELP8V3 = 0x0132
    SIPROLAB_G729 = 0x0133
    SIPROLAB_G729A = 0x0134
    SIPROLAB_KELVIN = 0x0135
    VOICEAGE_AMR = 0x0136
    G726ADPCM = 0x0140
    DICTAPHONE_CELP68 = 0x0141
    DICTAPHONE_CELP54 = 0x0142
    QUALCOMM_PUREVOICE = 0x0150
    QUALCOMM_HALFRATE = 0x0151
    TUBGSM = 0x0155
    MSAUDIO1 = 0x0160
    WMAUDIO2 = 0x0161
    WMAUDIO3 = 0x0162
    WMAUDIO_LOSSLESS = 0x0163
    WMASPDIF = 0x0164
    UNISYS_NAP_ADPCM = 0x0170
    UNISYS_NAP_ULAW = 0x0171
    UNISYS_NAP_ALAW = 0x0172
    UNISYS_NAP_16K = 0x0173
    SYCOM_ACM_SYC008 = 0x0174
    SYCOM_ACM_SYC701_G726L = 0x0175
    SYCOM_ACM_SYC701_CELP54 = 0x0176
    SYCOM_ACM_SYC701_CELP68 = 0x0177
    KNOWLEDGE_ADVENTURE_ADPCM = 0x0178
    FRAUNHOFER_IIS_MPEG2_AAC = 0x0180
    DTS_DS = 0x0190
    CREATIVE_ADPCM = 0x0200
    CREATIVE_FASTSPEECH8 = 0x0202
    CREATIVE_FASTSPEECH10 = 0x0203
    UHER_ADPCM = 0x0210
    ULEAD_DV_AUDIO = 0x0215
    ULEAD_DV_AUDIO_1 = 0x0216
    QUARTERDECK = 0x0220
    ILINK_VC = 0x0230
    RAW_SPORT = 0x0240
    ESST_AC3 = 0x0241
    GENERIC_PASSTHRU = 0x0249
    IPI_HSX = 0x0250
    IPI_RPELP = 0x0251
    CS2 = 0x0260
    SONY_SCX = 0x0270
    SONY_SCY = 0x0271
    SONY_ATRAC3 = 0x0272
    SONY_SPC = 0x0273
    TELUM_AUDIO = 0x0280
    TELUM_IA_AUDIO = 0x0281
    NORCOM_VOICE_SYSTEMS_ADPCM = 0x0285
    FM_TOWNS_SND = 0x0300
    MICRONAS = 0x0350
    MICRONAS_CELP833 = 0x0351
    BTV_DIGITAL = 0x0400
    INTEL_MUSIC_CODER = 0x0401
    INDEO_AUDIO = 0x0402
    QDESIGN_MUSIC = 0x0450
    ON2_VP7_AUDIO = 0x0500
    ON2_VP6_AUDIO = 0x0501
    VME_VMPCM = 0x0680
    TPC = 0x0681
    LIGHTWAVE_LOSSLESS = 0x08AE
    OLIGSM = 0x1000
    OLIADPCM = 0x1001
    OLICELP = 0x1002
    OLISBC = 0x1003
    OLIOPR = 0x1004
    LH_CODEC = 0x1100
    LH_CODEC_CELP = 0x1101
    LH_CODEC_SBC8 = 0x1102
    LH_CODEC_SBC12 = 0x1103
    LH_CODEC_SBC16 = 0x1104
    NORRIS = 0x1400
    ISIAUDIO_2 = 0x1401
    SOUNDSPACE_MUSICOMPRESS = 0x1500
    MPEG_ADTS_AAC = 0x1600
    MPEG_RAW_AAC = 0x1601
    MPEG_LOAS = 0x1602
    NOKIA_MPEG_ADTS_AAC = 0x1608
    NOKIA_MPEG_RAW_AAC = 0x1609
    VODAFONE_MPEG_ADTS_AAC = 0x160A
    VODAFONE_MPEG_RAW_AAC = 0x160B
    MPEG_HEAAC = 0x1610
    VOXWARE_RT24_SPEECH = 0x181C
    SONICFOUNDRY_LOSSLESS = 0x1971
    INNINGS_TELECOM_ADPCM = 0x1979
    LUCENT_SX8300P = 0x1C07
    LUCENT_SX5363S = 0x1C0C
    CUSEEME = 0x1F03
    NTCSOFT_ALF2CM_ACM = 0x1FC4
    DVM = 0x2000
    DTS2 = 0x2001
    MAKEAVIS = 0x3313
    DIVIO_MPEG4_AAC = 0x4143
    NOKIA_ADAPTIVE_MULTIRATE = 0x4201
    DIVIO_G726 = 0x4243
    LEAD_SPEECH = 0x434C
    LEAD_VORBIS = 0x564C
    WAVPACK_AUDIO = 0x5756
    OGG_VORBIS_MODE_1 = 0x674F
    OGG_VORBIS_MODE_2 = 0x6750
    OGG_VORBIS_MODE_3 = 0x6751
    OGG_VORBIS_MODE_1_PLUS = 0x676F
    OGG_VORBIS_MODE_2_PLUS = 0x6770
    OGG_VORBIS_MODE_3_PLUS = 0x6771
    ALAC = 0x6C61
    _3COM_NBX = 0x7000  # Can't have leading digit
    OPUS = 0x704F
    FAAD_AAC = 0x706D
    AMR_NB = 0x7361
    AMR_WB = 0x7362
    AMR_WP = 0x7363
    GSM_AMR_CBR = 0x7A21
    GSM_AMR_VBR_SID = 0x7A22
    COMVERSE_INFOSYS_G723_1 = 0xA100
    COMVERSE_INFOSYS_AVQSBC = 0xA101
    COMVERSE_INFOSYS_SBC = 0xA102
    SYMBOL_G729_A = 0xA103
    VOICEAGE_AMR_WB = 0xA104
    INGENIENT_G726 = 0xA105
    MPEG4_AAC = 0xA106
    ENCORE_G726 = 0xA107
    ZOLL_ASAO = 0xA108
    SPEEX_VOICE = 0xA109
    VIANIX_MASC = 0xA10A
    WM9_SPECTRUM_ANALYZER = 0xA10B
    WMF_SPECTRUM_ANAYZER = 0xA10C
    GSM_610 = 0xA10D
    GSM_620 = 0xA10E
    GSM_660 = 0xA10F
    GSM_690 = 0xA110
    GSM_ADAPTIVE_MULTIRATE_WB = 0xA111
    POLYCOM_G722 = 0xA112
    POLYCOM_G728 = 0xA113
    POLYCOM_G729_A = 0xA114
    POLYCOM_SIREN = 0xA115
    GLOBAL_IP_ILBC = 0xA116
    RADIOTIME_TIME_SHIFT_RADIO = 0xA117
    NICE_ACA = 0xA118
    NICE_ADPCM = 0xA119
    VOCORD_G721 = 0xA11A
    VOCORD_G726 = 0xA11B
    VOCORD_G722_1 = 0xA11C
    VOCORD_G728 = 0xA11D
    VOCORD_G729 = 0xA11E
    VOCORD_G729_A = 0xA11F
    VOCORD_G723_1 = 0xA120
    VOCORD_LBC = 0xA121
    NICE_G728 = 0xA122
    FRACE_TELECOM_G729 = 0xA123
    CODIAN = 0xA124
    FLAC = 0xF1AC
    EXTENSIBLE = 0xFFFE
    DEVELOPMENT = 0xFFFF


KNOWN_WAVE_FORMATS = {WAVE_FORMAT.PCM, WAVE_FORMAT.IEEE_FLOAT}


def _raise_bad_format(format_tag):
    try:
        format_name = WAVE_FORMAT(format_tag).name
    except ValueError:
        format_name = f"{format_tag:#06x}"
    raise ValueError(
        f"Unknown wave file format: {format_name}. Supported "
        "formats: " + ", ".join(x.name for x in KNOWN_WAVE_FORMATS)
    )


def _read_fmt_chunk(fid, is_big_endian):
    """
    Returns
    -------
    size : int
        size of format subchunk in bytes (minus 8 for "fmt " and itself)
    format_tag : int
        PCM, float, or compressed format
    channels : int
        number of channels
    fs : int
        sampling frequency in samples per second
    bytes_per_second : int
        overall byte rate for the file
    block_align : int
        bytes per sample, including all channels
    bit_depth : int
        bits per sample

    Notes
    -----
    Assumes file pointer is immediately after the 'fmt ' id
    """
    if is_big_endian:
        fmt = ">"
    else:
        fmt = "<"

    size = struct.unpack(fmt + "I", fid.read(4))[0]

    if size < 16:
        raise ValueError("Binary structure of wave file is not compliant")

    res = struct.unpack(fmt + "HHIIHH", fid.read(16))
    bytes_read = 16

    format_tag, channels, fs, bytes_per_second, block_align, bit_depth = res

    if format_tag == WAVE_FORMAT.EXTENSIBLE and size >= (16 + 2):
        ext_chunk_size = struct.unpack(fmt + "H", fid.read(2))[0]
        bytes_read += 2
        if ext_chunk_size >= 22:
            extensible_chunk_data = fid.read(22)
            bytes_read += 22
            raw_guid = extensible_chunk_data[2 + 4 : 2 + 4 + 16]
            # GUID template {XXXXXXXX-0000-0010-8000-00AA00389B71} (RFC-2361)
            # MS GUID byte order: first three groups are native byte order,
            # rest is Big Endian
            if is_big_endian:
                tail = b"\x00\x00\x00\x10\x80\x00\x00\xAA\x00\x38\x9B\x71"
            else:
                tail = b"\x00\x00\x10\x00\x80\x00\x00\xAA\x00\x38\x9B\x71"
            if raw_guid.endswith(tail):
                format_tag = struct.unpack(fmt + "I", raw_guid[:4])[0]
        else:
            raise ValueError("Binary structure of wave file is not compliant")

    if format_tag not in KNOWN_WAVE_FORMATS:
        _raise_bad_format(format_tag)

    # move file pointer to next chunk
    if size > bytes_read:
        fid.read(size - bytes_read)

    # fmt should always be 16, 18 or 40, but handle it just in case
    _handle_pad_byte(fid, size)

    return (size, format_tag, channels, fs, bytes_per_second, block_align, bit_depth)


def _read_data_chunk(
    fid, format_tag, channels, bit_depth, is_big_endian, block_align, mmap=False
):
    """
    Notes
    -----
    Assumes file pointer is immediately after the 'data' id

    It's possible to not use all available bits in a container, or to store
    samples in a container bigger than necessary, so bytes_per_sample uses
    the actual reported container size (nBlockAlign / nChannels). Real-world
    examples:

    Adobe Audition's "24-bit packed int (type 1, 20-bit)"

        nChannels = 2, nBlockAlign = 6, wBitsPerSample = 20

    http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Samples/AFsp/M1F1-int12-AFsp.wav
    is:

        nChannels = 2, nBlockAlign = 4, wBitsPerSample = 12

    http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Docs/multichaudP.pdf
    gives an example of:

        nChannels = 2, nBlockAlign = 8, wBitsPerSample = 20
    """
    if is_big_endian:
        fmt = ">"
    else:
        fmt = "<"

    # Size of the data subchunk in bytes
    size = struct.unpack(fmt + "I", fid.read(4))[0]

    # Number of bytes per sample (sample container size)
    bytes_per_sample = block_align // channels
    n_samples = size // bytes_per_sample

    if format_tag == WAVE_FORMAT.PCM:
        if 1 <= bit_depth <= 8:
            dtype = "u1"  # WAV of 8-bit integer or less are unsigned
        elif bytes_per_sample in {3, 5, 6, 7}:
            # No compatible dtype. Load as raw bytes for reshaping later.
            dtype = "V1"
        elif bit_depth <= 64:
            # Remaining bit depths can map directly to signed numpy dtypes
            dtype = f"{fmt}i{bytes_per_sample}"
        else:
            raise ValueError(
                "Unsupported bit depth: the WAV file "
                f"has {bit_depth}-bit integer data."
            )
    elif format_tag == WAVE_FORMAT.IEEE_FLOAT:
        if bit_depth in {32, 64}:
            dtype = f"{fmt}f{bytes_per_sample}"
        else:
            raise ValueError(
                "Unsupported bit depth: the WAV file "
                f"has {bit_depth}-bit floating-point data."
            )
    else:
        _raise_bad_format(format_tag)

    start = fid.tell()
    if not mmap:
        try:
            count = size if dtype == "V1" else n_samples
            data = numpy.fromfile(fid, dtype=dtype, count=count)
        except io.UnsupportedOperation:  # not a C-like file
            fid.seek(start, 0)  # just in case it seeked, though it shouldn't
            data = numpy.frombuffer(fid.read(size), dtype=dtype)

        if dtype == "V1":
            # Rearrange raw bytes into smallest compatible numpy dtype
            dt = f"{fmt}i4" if bytes_per_sample == 3 else f"{fmt}i8"
            a = numpy.zeros(
                (len(data) // bytes_per_sample, numpy.dtype(dt).itemsize), dtype="V1"
            )
            if is_big_endian:
                a[:, :bytes_per_sample] = data.reshape((-1, bytes_per_sample))
            else:
                a[:, -bytes_per_sample:] = data.reshape((-1, bytes_per_sample))
            data = a.view(dt).reshape(a.shape[:-1])
    else:
        if bytes_per_sample in {1, 2, 4, 8}:
            start = fid.tell()
            data = numpy.memmap(
                fid, dtype=dtype, mode="c", offset=start, shape=(n_samples,)
            )
            fid.seek(start + size)
        else:
            raise ValueError(
                "mmap=True not compatible with "
                f"{bytes_per_sample}-byte container size."
            )

    _handle_pad_byte(fid, size)

    if channels > 1:
        data = data.reshape(-1, channels)
    return data


def _skip_unknown_chunk(fid, is_big_endian):
    if is_big_endian:
        fmt = ">I"
    else:
        fmt = "<I"

    data = fid.read(4)
    # call unpack() and seek() only if we have really read data from file
    # otherwise empty read at the end of the file would trigger
    # unnecessary exception at unpack() call
    # in case data equals somehow to 0, there is no need for seek() anyway
    if data:
        size = struct.unpack(fmt, data)[0]
        fid.seek(size, 1)
        _handle_pad_byte(fid, size)


def _read_riff_chunk(fid):
    str1 = fid.read(4)  # File signature
    if str1 == b"RIFF":
        is_big_endian = False
        fmt = "<I"
    elif str1 == b"RIFX":
        is_big_endian = True
        fmt = ">I"
    else:
        # There are also .wav files with "FFIR" or "XFIR" signatures?
        raise ValueError(
            f"File format {repr(str1)} not understood. Only "
            "'RIFF' and 'RIFX' supported."
        )

    # Size of entire file
    file_size = struct.unpack(fmt, fid.read(4))[0] + 8

    str2 = fid.read(4)
    if str2 != b"WAVE":
        raise ValueError(f"Not a WAV file. RIFF form type is {repr(str2)}.")

    return file_size, is_big_endian


def _handle_pad_byte(fid, size):
    # "If the chunk size is an odd number of bytes, a pad byte with value zero
    # is written after ckData." So we need to seek past this after each chunk.
    if size % 2:
        fid.seek(1, 1)


def read(filename, mmap=False):
    """
    Open a WAV file.

    Return the sample rate (in samples/sec) and data from an LPCM WAV file.

    Parameters
    ----------
    filename : string or open file handle
        Input WAV file.
    mmap : bool, optional
        Whether to read data as memory-mapped (default: False). Not compatible
        with some bit depths; see Notes. Only to be used on real files.

        .. versionadded:: 0.12.0

    Returns
    -------
    rate : int
        Sample rate of WAV file.
    data : numpy array
        Data read from WAV file. Data-type is determined from the file;
        see Notes. Data is 1-D for 1-channel WAV, or 2-D of shape
        (Nsamples, Nchannels) otherwise. If a file-like input without a
        C-like file descriptor (e.g., :class:`python:io.BytesIO`) is
        passed, this will not be writeable.

    Notes
    -----
    Common data types: [1]_

    ===================== =========== =========== =============
    WAV format            Min         Max         NumPy dtype
    ===================== =========== =========== =============
    32-bit floating-point -1.0        +1.0        float32
    32-bit integer PCM    -2147483648 +2147483647 int32
    24-bit integer PCM    -2147483648 +2147483392 int32
    16-bit integer PCM    -32768      +32767      int16
    8-bit integer PCM     0           255         uint8
    ===================== =========== =========== =============

    WAV files can specify arbitrary bit depth, and this function supports
    reading any integer PCM depth from 1 to 64 bits. Data is returned in the
    smallest compatible numpy int type, in left-justified format. 8-bit and
    lower is unsigned, while 9-bit and higher is signed.

    For example, 24-bit data will be stored as int32, with the MSB of the
    24-bit data stored at the MSB of the int32, and typically the least
    significant byte is 0x00. (However, if a file actually contains data past
    its specified bit depth, those bits will be read and output, too. [2]_)

    This bit justification and sign matches WAV's native internal format, which
    allows memory mapping of WAV files that use 1, 2, 4, or 8 bytes per sample
    (so 24-bit files cannot be memory-mapped, but 32-bit can).

    IEEE float PCM in 32- or 64-bit format is supported, with or without mmap.
    Values exceeding [-1, +1] are not clipped.

    Non-linear PCM (mu-law, A-law) is not supported.

    References
    ----------
    .. [1] IBM Corporation and Microsoft Corporation, "Multimedia Programming
       Interface and Data Specifications 1.0", section "Data Format of the
       Samples", August 1991
       http://www.tactilemedia.com/info/MCI_Control_Info.html
    .. [2] Adobe Systems Incorporated, "Adobe Audition 3 User Guide", section
       "Audio file formats: 24-bit Packed Int (type 1, 20-bit)", 2007

    Examples
    --------
    >>> from os.path import dirname, join as pjoin
    >>> from scipy.io import wavfile
    >>> import scipy.io

    Get the filename for an example .wav file from the tests/data directory.

    >>> data_dir = pjoin(dirname(scipy.io.__file__), 'tests', 'data')
    >>> wav_fname = pjoin(data_dir, 'test-44100Hz-2ch-32bit-float-be.wav')

    Load the .wav file contents.

    >>> samplerate, data = wavfile.read(wav_fname)
    >>> print(f"number of channels = {data.shape[1]}")
    number of channels = 2
    >>> length = data.shape[0] / samplerate
    >>> print(f"length = {length}s")
    length = 0.01s

    Plot the waveform.

    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> time = np.linspace(0., length, data.shape[0])
    >>> plt.plot(time, data[:, 0], label="Left channel")
    >>> plt.plot(time, data[:, 1], label="Right channel")
    >>> plt.legend()
    >>> plt.xlabel("Time [s]")
    >>> plt.ylabel("Amplitude")
    >>> plt.show()

    """
    if hasattr(filename, "read"):
        fid = filename
        mmap = False
    else:
        # pylint: disable=consider-using-with
        fid = open(filename, "rb")

    try:
        file_size, is_big_endian = _read_riff_chunk(fid)
        fmt_chunk_received = False
        data_chunk_received = False
        while fid.tell() < file_size:
            # read the next chunk
            chunk_id = fid.read(4)

            if not chunk_id:
                if data_chunk_received:
                    # End of file but data successfully read
                    warnings.warn(
f"Reached EOF prematurely; finished at {fid.tell()} bytes, "
|
||||
"expected {file_size} bytes from header.",
|
||||
                        WavFileWarning,
                        stacklevel=2,
                    )
                    break

                raise ValueError("Unexpected end of file.")
            if len(chunk_id) < 4:
                msg = f"Incomplete chunk ID: {repr(chunk_id)}"
                # If we have the data, ignore the broken chunk
                if fmt_chunk_received and data_chunk_received:
                    warnings.warn(msg + ", ignoring it.", WavFileWarning, stacklevel=2)
                else:
                    raise ValueError(msg)

            if chunk_id == b"fmt ":
                fmt_chunk_received = True
                fmt_chunk = _read_fmt_chunk(fid, is_big_endian)
                format_tag, channels, fs = fmt_chunk[1:4]
                bit_depth = fmt_chunk[6]
                block_align = fmt_chunk[5]
            elif chunk_id == b"fact":
                _skip_unknown_chunk(fid, is_big_endian)
            elif chunk_id == b"data":
                data_chunk_received = True
                if not fmt_chunk_received:
                    raise ValueError("No fmt chunk before data")
                data = _read_data_chunk(
                    fid,
                    format_tag,
                    channels,
                    bit_depth,
                    is_big_endian,
                    block_align,
                    mmap,
                )
            elif chunk_id == b"LIST":
                # Someday this could be handled properly but for now skip it
                _skip_unknown_chunk(fid, is_big_endian)
            elif chunk_id in {b"JUNK", b"Fake"}:
                # Skip alignment chunks without warning
                _skip_unknown_chunk(fid, is_big_endian)
            else:
                warnings.warn(
                    "Chunk (non-data) not understood, skipping it.",
                    WavFileWarning,
                    stacklevel=2,
                )
                _skip_unknown_chunk(fid, is_big_endian)
    finally:
        if not hasattr(filename, "read"):
            fid.close()
        else:
            fid.seek(0)

    return fs, data


def write(filename, rate, data):
    """
    Write a NumPy array as a WAV file.

    Parameters
    ----------
    filename : string or open file handle
        Output wav file.
    rate : int
        The sample rate (in samples/sec).
    data : ndarray
        A 1-D or 2-D NumPy array of either integer or float data-type.

    Notes
    -----
    * Writes a simple uncompressed WAV file.
    * To write multiple-channels, use a 2-D array of shape
      (Nsamples, Nchannels).
    * The bits-per-sample and PCM/float will be determined by the data-type.

    Common data types: [1]_

    ===================== =========== =========== =============
    WAV format            Min         Max         NumPy dtype
    ===================== =========== =========== =============
    32-bit floating-point -1.0        +1.0        float32
    32-bit PCM            -2147483648 +2147483647 int32
    16-bit PCM            -32768      +32767      int16
    8-bit PCM             0           255         uint8
    ===================== =========== =========== =============

    Note that 8-bit PCM is unsigned.

    References
    ----------
    .. [1] IBM Corporation and Microsoft Corporation, "Multimedia Programming
       Interface and Data Specifications 1.0", section "Data Format of the
       Samples", August 1991
       http://www.tactilemedia.com/info/MCI_Control_Info.html

    Examples
    --------
    Create a 100Hz sine wave, sampled at 44100Hz.
    Write to 16-bit PCM, Mono.

    >>> from scipy.io.wavfile import write
    >>> samplerate = 44100; fs = 100
    >>> t = np.linspace(0., 1., samplerate)
    >>> amplitude = np.iinfo(np.int16).max
    >>> data = amplitude * np.sin(2. * np.pi * fs * t)
    >>> write("example.wav", samplerate, data.astype(np.int16))

    """
    if hasattr(filename, "write"):
        fid = filename
    else:
        # pylint: disable=consider-using-with
        fid = open(filename, "wb")

    fs = rate

    try:
        dkind = data.dtype.kind
        if not (
            dkind == "i" or dkind == "f" or (dkind == "u" and data.dtype.itemsize == 1)
        ):
            raise ValueError(f"Unsupported data type '{data.dtype}'")

        header_data = b""

        header_data += b"RIFF"
        header_data += b"\x00\x00\x00\x00"
        header_data += b"WAVE"

        # fmt chunk
        header_data += b"fmt "
        if dkind == "f":
            format_tag = WAVE_FORMAT.IEEE_FLOAT
        else:
            format_tag = WAVE_FORMAT.PCM
        if data.ndim == 1:
            channels = 1
        else:
            channels = data.shape[1]
        bit_depth = data.dtype.itemsize * 8
        bytes_per_second = fs * (bit_depth // 8) * channels
        block_align = channels * (bit_depth // 8)

        fmt_chunk_data = struct.pack(
            "<HHIIHH",
            format_tag,
            channels,
            fs,
            bytes_per_second,
            block_align,
            bit_depth,
        )
        if not (dkind in ("i", "u")):
            # add cbSize field for non-PCM files
            fmt_chunk_data += b"\x00\x00"

        header_data += struct.pack("<I", len(fmt_chunk_data))
        header_data += fmt_chunk_data

        # fact chunk (non-PCM files)
        if not (dkind in ("i", "u")):
            header_data += b"fact"
            header_data += struct.pack("<II", 4, data.shape[0])

        # check data size (needs to be immediately before the data chunk)
        if ((len(header_data) - 4 - 4) + (4 + 4 + data.nbytes)) > 0xFFFFFFFF:
            raise ValueError("Data exceeds wave file size limit")

        fid.write(header_data)

        # data chunk
        fid.write(b"data")
        fid.write(struct.pack("<I", data.nbytes))
        if data.dtype.byteorder == ">" or (
            data.dtype.byteorder == "=" and sys.byteorder == "big"
        ):
            data = data.byteswap()
        _array_tofile(fid, data)

        # Determine file size and place it in correct
        # position at start of the file.
        size = fid.tell()
        fid.seek(4)
        fid.write(struct.pack("<I", size - 8))

    finally:
        if not hasattr(filename, "write"):
            fid.close()
        else:
            fid.seek(0)


def _array_tofile(fid, data):
    # ravel gives a c-contiguous buffer
    fid.write(data.ravel().view("b").data)
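
# Round-trip sketch for this module (illustrative, mirroring the docstring
# examples above; not part of the original file):
#
#   buf = io.BytesIO()
#   tone = 0.5 * numpy.sin(2 * numpy.pi * 440 * numpy.arange(22050) / 22050)
#   write(buf, 22050, tone.astype(numpy.float32))   # 32-bit float WAV
#   rate, data = read(buf)
#   assert rate == 22050 and data.dtype == numpy.float32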
@ -0,0 +1,11 @@

[mypy]

[mypy-setuptools.*]
ignore_missing_imports = True

[mypy-librosa.*]
ignore_missing_imports = True

[mypy-onnxruntime.*]
ignore_missing_imports = True
@ -0,0 +1,7 @@
cython>=0.29.0,<1
espeak-phonemizer>=1.1.0,<2
librosa>=0.9.2,<1
numpy>=1.19.0
onnxruntime~=1.11.0
pytorch-lightning~=1.7.0
torch~=1.11.0
@ -0,0 +1,7 @@
black==22.3.0
coverage==5.0.4
flake8==3.7.9
mypy==0.910
pylint==2.10.2
pytest==5.4.1
pytest-cov==2.8.1
@ -0,0 +1,29 @@
#!/usr/bin/env bash

# Runs formatters, linters, and type checkers on Python code.

set -eo pipefail

# Directory of *this* script
this_dir="$( cd "$( dirname "$0" )" && pwd )"

base_dir="$(realpath "${this_dir}/..")"

# Path to virtual environment
: "${venv:=${base_dir}/.venv}"

if [ -d "${venv}" ]; then
    # Activate virtual environment if available
    source "${venv}/bin/activate"
fi

python_files=("${base_dir}/larynx_train")

# Format code
black "${python_files[@]}"
isort "${python_files[@]}"

# Check
flake8 "${python_files[@]}"
pylint "${python_files[@]}"
mypy "${python_files[@]}"
@ -0,0 +1,61 @@
#!/usr/bin/env python3
from collections import defaultdict
from pathlib import Path

import setuptools
from setuptools import setup

this_dir = Path(__file__).parent
module_dir = this_dir / "larynx_train"

# -----------------------------------------------------------------------------

# Load README in as long description
long_description: str = ""
readme_path = this_dir / "README.md"
if readme_path.is_file():
    long_description = readme_path.read_text(encoding="utf-8")

requirements = []
requirements_path = this_dir / "requirements.txt"
if requirements_path.is_file():
    with open(requirements_path, "r", encoding="utf-8") as requirements_file:
        requirements = requirements_file.read().splitlines()

version_path = module_dir / "VERSION"
with open(version_path, "r", encoding="utf-8") as version_file:
    version = version_file.read().strip()

# -----------------------------------------------------------------------------

setup(
    name="larynx_train",
    version=version,
    description="A fast and local neural text to speech system",
    long_description=long_description,
    url="http://github.com/rhasspy/larynx",
    author="Michael Hansen",
    author_email="mike@rhasspy.org",
    license="MIT",
    packages=setuptools.find_packages(),
    package_data={
        "larynx_train": ["VERSION", "py.typed"],
    },
    install_requires=requirements,
    extras_require={':python_version<"3.9"': ["importlib_resources"]},
    entry_points={
        "console_scripts": [
            "larynx-train = larynx_train.__main__:main",
        ]
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Topic :: Text Processing :: Linguistic",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
    ],
    keywords="rhasspy tts speech voice",
)