piper/src/python/piper_train/preprocess.py

#!/usr/bin/env python3
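"""Preprocess a voice dataset for Piper training.

Phonemizes each utterance's text with eSpeak-ng, normalizes and caches its
audio, and writes a config.json and dataset.jsonl to the output directory.
"""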
import argparse
import csv
import dataclasses
import itertools
import json
import logging
import os
from collections import Counter
from dataclasses import dataclass
from multiprocessing import JoinableQueue, Process, Queue
from pathlib import Path
from typing import Dict, Iterable, List, Optional
from espeak_phonemizer import Phonemizer
from .norm_audio import cache_norm_audio, make_silence_detector
from .phonemize import DEFAULT_PHONEME_ID_MAP, phonemes_to_ids, phonemize
_LOGGER = logging.getLogger("preprocess")
def main() -> None:
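    """Parse arguments, write the dataset config, and preprocess all utterances in parallel."""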
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir", required=True, help="Directory with audio dataset"
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write output files for training",
    )
    parser.add_argument("--language", required=True, help="eSpeak-ng voice")
    parser.add_argument(
        "--sample-rate",
        type=int,
        required=True,
        help="Target sample rate for voice (hertz)",
    )
    parser.add_argument(
        "--dataset-format", choices=("ljspeech", "mycroft"), required=True
    )
    parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
    parser.add_argument("--max-workers", type=int)
    parser.add_argument(
        "--single-speaker", action="store_true", help="Force single speaker dataset"
    )
    parser.add_argument(
        "--speaker-id", type=int, help="Add speaker id to single speaker dataset"
    )
    parser.add_argument(
        "--debug", action="store_true", help="Print DEBUG messages to the console"
    )
    args = parser.parse_args()

    if args.single_speaker and (args.speaker_id is not None):
        _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided")
        return

    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level)
    logging.getLogger().setLevel(level)

    # Prevent log spam
    logging.getLogger("numba").setLevel(logging.WARNING)

    # Convert to paths and create output directories
    args.input_dir = Path(args.input_dir)
    args.output_dir = Path(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    args.cache_dir = (
        Path(args.cache_dir)
        if args.cache_dir
        else args.output_dir / "cache" / str(args.sample_rate)
    )
    args.cache_dir.mkdir(parents=True, exist_ok=True)

    if args.dataset_format == "mycroft":
        make_dataset = mycroft_dataset
    else:
        make_dataset = ljspeech_dataset
    # Count speakers
    _LOGGER.debug("Counting number of speakers/utterances in the dataset")
    speaker_counts: Counter[str] = Counter()
    num_utterances = 0
    for utt in make_dataset(args.input_dir, args.single_speaker, args.speaker_id):
        speaker = utt.speaker or ""
        speaker_counts[speaker] += 1
        num_utterances += 1

    assert num_utterances > 0, "No utterances found"

    is_multispeaker = len(speaker_counts) > 1
    speaker_ids: Dict[str, int] = {}
    if is_multispeaker:
        _LOGGER.info("%s speakers detected", len(speaker_counts))

        # Assign speaker ids in order of utterance count, most frequent first
        for speaker_id, (speaker, _speaker_count) in enumerate(
            speaker_counts.most_common()
        ):
            speaker_ids[speaker] = speaker_id
    else:
        _LOGGER.info("Single speaker dataset")
    # Write config
    with open(args.output_dir / "config.json", "w", encoding="utf-8") as config_file:
        json.dump(
            {
                "audio": {
                    "sample_rate": args.sample_rate,
                },
                "espeak": {
                    "voice": args.language,
                },
                "inference": {"noise_scale": 0.667, "length_scale": 1, "noise_w": 0.8},
                "phoneme_map": {},
                "phoneme_id_map": DEFAULT_PHONEME_ID_MAP,
                "num_symbols": len(
                    set(itertools.chain.from_iterable(DEFAULT_PHONEME_ID_MAP.values()))
                ),
                "num_speakers": len(speaker_counts),
                "speaker_id_map": speaker_ids,
            },
            config_file,
            ensure_ascii=False,
            indent=4,
        )

    _LOGGER.info("Wrote dataset config")
    if (args.max_workers is None) or (args.max_workers < 1):
        args.max_workers = os.cpu_count()

    assert args.max_workers is not None

    # Split utterances across workers; guard against a zero batch size when
    # there are fewer utterances than twice the number of workers.
    batch_size = max(1, num_utterances // (args.max_workers * 2))

    # None is used as a sentinel that tells workers to stop
    queue_in: "Queue[Optional[List[Utterance]]]" = JoinableQueue()
    queue_out: "Queue[Optional[Utterance]]" = Queue()
    # Start workers
    processes = [
        Process(target=process_batch, args=(args, queue_in, queue_out))
        for _ in range(args.max_workers)
    ]
    for proc in processes:
        proc.start()

    _LOGGER.info(
        "Processing %s utterance(s) with %s worker(s)", num_utterances, args.max_workers
    )

    with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
        for utt_batch in batched(
            make_dataset(args.input_dir, args.single_speaker, args.speaker_id),
            batch_size,
        ):
            queue_in.put(utt_batch)

        _LOGGER.debug("Waiting for jobs to finish")
        for _ in range(num_utterances):
            utt = queue_out.get()
            if utt is not None:
                if utt.speaker is not None:
                    utt.speaker_id = speaker_ids[utt.speaker]

                # JSONL
                json.dump(
                    dataclasses.asdict(utt),
                    dataset_file,
                    ensure_ascii=False,
                    cls=PathEncoder,
                )
                print("", file=dataset_file)

    # Signal workers to stop
    for proc in processes:
        queue_in.put(None)

    # Wait for workers to stop
    for proc in processes:
        proc.join(timeout=1)
# -----------------------------------------------------------------------------
def process_batch(args: argparse.Namespace, queue_in: JoinableQueue, queue_out: Queue):
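    """Worker loop: phonemize text and normalize audio for each utterance in incoming batches, stopping on a None sentinel."""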
    try:
        silence_detector = make_silence_detector()
        phonemizer = Phonemizer(default_voice=args.language)

        while True:
            utt_batch = queue_in.get()
            if utt_batch is None:
                break

            for utt in utt_batch:
                try:
                    _LOGGER.debug(utt)
                    utt.phonemes = phonemize(utt.text, phonemizer)
                    utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
                    utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
                        utt.audio_path,
                        args.cache_dir,
                        silence_detector,
                        args.sample_rate,
                    )
                    queue_out.put(utt)
                except Exception:
                    _LOGGER.exception("Failed to process utterance: %s", utt)
                    queue_out.put(None)

            queue_in.task_done()
    except Exception:
        _LOGGER.exception("process_batch")
# -----------------------------------------------------------------------------
@dataclass
class Utterance:
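    """A single audio/text training example; the phoneme and cached-audio fields are filled in during preprocessing."""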
    text: str
    audio_path: Path
    speaker: Optional[str] = None
    speaker_id: Optional[int] = None
    phonemes: Optional[List[str]] = None
    phoneme_ids: Optional[List[int]] = None
    audio_norm_path: Optional[Path] = None
    audio_spec_path: Optional[Path] = None
class PathEncoder(json.JSONEncoder):
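    """JSON encoder that serializes pathlib.Path values as strings."""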
    def default(self, o):
        if isinstance(o, Path):
            return str(o)

        return super().default(o)
def ljspeech_dataset(
    dataset_dir: Path, is_single_speaker: bool, speaker_id: Optional[int] = None
) -> Iterable[Utterance]:
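    """Load utterances from an LJSpeech-style dataset: metadata.csv plus a wav/ or wavs/ directory."""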
    # filename|speaker|text
    # speaker is optional
    metadata_path = dataset_dir / "metadata.csv"
    assert metadata_path.exists(), f"Missing {metadata_path}"

    wav_dir = dataset_dir / "wav"
    if not wav_dir.is_dir():
        wav_dir = dataset_dir / "wavs"

    with open(metadata_path, "r", encoding="utf-8") as csv_file:
        reader = csv.reader(csv_file, delimiter="|")
        for row in reader:
            assert len(row) >= 2, "Not enough columns"

            speaker: Optional[str] = None
            if is_single_speaker or (len(row) == 2):
                filename, text = row[0], row[-1]
            else:
                filename, speaker, text = row[0], row[1], row[-1]

            # Try file name relative to metadata
            wav_path = metadata_path.parent / filename
            if not wav_path.exists():
                # Try with .wav
                wav_path = metadata_path.parent / f"{filename}.wav"

            if not wav_path.exists():
                # Try wav/ or wavs/
                wav_path = wav_dir / filename

            if not wav_path.exists():
                # Try with .wav
                wav_path = wav_dir / f"{filename}.wav"

            if not wav_path.exists():
                _LOGGER.warning("Missing %s", filename)
                continue

            yield Utterance(
                text=text, audio_path=wav_path, speaker=speaker, speaker_id=speaker_id
            )
def mycroft_dataset(
    dataset_dir: Path, is_single_speaker: bool, speaker_id: Optional[int] = None
) -> Iterable[Utterance]:
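    """Load utterances from a Mycroft-style dataset: each .info transcript file has a matching .wav file."""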
    for info_path in dataset_dir.glob("*.info"):
        wav_path = info_path.with_suffix(".wav")
        if wav_path.exists():
            text = info_path.read_text(encoding="utf-8").strip()
            yield Utterance(text=text, audio_path=wav_path, speaker_id=speaker_id)
# -----------------------------------------------------------------------------
def batched(iterable, n):
"Batch data into lists of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
batch = list(itertools.islice(it, n))
while batch:
yield batch
batch = list(itertools.islice(it, n))
# -----------------------------------------------------------------------------
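# Example invocation (a sketch, not from the repo docs; the paths below are
# hypothetical, and "en-us" / 22050 Hz are just common choices). The script
# uses relative imports, so it must be run as a module:
#
#   python3 -m piper_train.preprocess \
#       --input-dir /path/to/my-dataset \
#       --output-dir /path/to/my-training-dir \
#       --language en-us \
#       --sample-rate 22050 \
#       --dataset-format ljspeech
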
if __name__ == "__main__":
    main()