Add --max-phoneme-ids

pull/2/head
Michael Hansen 2 years ago
parent 11b294a461
commit 5f704228f6

@ -67,15 +67,19 @@ class LarynxDataset(Dataset):
def __init__(
self,
dataset_paths: List[Union[str, Path]], # settings: LarynxDatasetSettings
dataset_paths: List[Union[str, Path]],
max_phoneme_ids: Optional[int] = None,
):
# self.settings = settings
self.utterances: List[Utterance] = []
for dataset_path in dataset_paths:
dataset_path = Path(dataset_path)
_LOGGER.debug("Loading dataset: %s", dataset_path)
self.utterances.extend(LarynxDataset.load_dataset(dataset_path))
self.utterances.extend(
LarynxDataset.load_dataset(
dataset_path, max_phoneme_ids=max_phoneme_ids
)
)
def __len__(self):
return len(self.utterances)
@ -93,7 +97,11 @@ class LarynxDataset(Dataset):
)
@staticmethod
def load_dataset(dataset_path: Path) -> Iterable[Utterance]:
def load_dataset(
dataset_path: Path, max_phoneme_ids: Optional[int] = None,
) -> Iterable[Utterance]:
num_skipped = 0
with open(dataset_path, "r", encoding="utf-8") as dataset_file:
for line_idx, line in enumerate(dataset_file):
line = line.strip()
@ -101,15 +109,21 @@ class LarynxDataset(Dataset):
continue
try:
yield LarynxDataset.load_utterance(line)
utt = LarynxDataset.load_utterance(line)
if (max_phoneme_ids is None) or (
len(utt.phoneme_ids) <= max_phoneme_ids
):
yield utt
else:
num_skipped += 1
except Exception:
_LOGGER.exception(
"Error on line %s of %s: %s",
line_idx + 1,
dataset_path,
line,
"Error on line %s of %s: %s", line_idx + 1, dataset_path, line,
)
if num_skipped > 0:
_LOGGER.warning("Skipped %s utterance(s)", num_skipped)
@staticmethod
def load_utterance(line: str) -> Utterance:
utt_dict = json.loads(line)

@ -25,11 +25,7 @@ class VitsModel(pl.LightningModule):
# audio
resblock="2",
resblock_kernel_sizes=(3, 5, 7),
resblock_dilation_sizes=(
(1, 2),
(2, 6),
(3, 12),
),
resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
upsample_rates=(8, 8, 4),
upsample_initial_channel=256,
upsample_kernel_sizes=(16, 16, 8),
@ -72,7 +68,8 @@ class VitsModel(pl.LightningModule):
seed: int = 1234,
num_test_examples: int = 5,
validation_split: float = 0.1,
**kwargs
max_phoneme_ids: Optional[int] = None,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
@ -111,14 +108,21 @@ class VitsModel(pl.LightningModule):
self._train_dataset: Optional[Dataset] = None
self._val_dataset: Optional[Dataset] = None
self._test_dataset: Optional[Dataset] = None
self._load_datasets(validation_split, num_test_examples)
self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)
# State kept between training optimizers
self._y = None
self._y_hat = None
def _load_datasets(self, validation_split: float, num_test_examples: int):
full_dataset = LarynxDataset(self.hparams.dataset)
def _load_datasets(
self,
validation_split: float,
num_test_examples: int,
max_phoneme_ids: Optional[int] = None,
):
full_dataset = LarynxDataset(
self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
)
valid_set_size = int(len(full_dataset) * validation_split)
train_set_size = len(full_dataset) - valid_set_size - num_test_examples
@ -211,9 +215,7 @@ class VitsModel(pl.LightningModule):
self.hparams.mel_fmax,
)
y_mel = slice_segments(
mel,
ids_slice,
self.hparams.segment_size // self.hparams.hop_length,
mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
@ -226,9 +228,7 @@ class VitsModel(pl.LightningModule):
self.hparams.mel_fmax,
)
y = slice_segments(
y,
ids_slice * self.hparams.hop_length,
self.hparams.segment_size,
y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
) # slice
# Save for training_step_d
@ -320,6 +320,11 @@ class VitsModel(pl.LightningModule):
parser.add_argument("--batch-size", type=int, required=True)
parser.add_argument("--validation-split", type=float, default=0.1)
parser.add_argument("--num-test-examples", type=int, default=5)
parser.add_argument(
"--max-phoneme-ids",
type=int,
help="Exclude utterances with phoneme id lists longer than this",
)
#
parser.add_argument("--hidden-channels", type=int, default=192)
parser.add_argument("--inter-channels", type=int, default=192)

Loading…
Cancel
Save