@@ -25,11 +25,7 @@ class VitsModel(pl.LightningModule):
         # audio
         resblock="2",
         resblock_kernel_sizes=(3, 5, 7),
-        resblock_dilation_sizes=(
-            (1, 2),
-            (2, 6),
-            (3, 12),
-        ),
+        resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
         upsample_rates=(8, 8, 4),
         upsample_initial_channel=256,
         upsample_kernel_sizes=(16, 16, 8),
@@ -72,7 +68,8 @@ class VitsModel(pl.LightningModule):
         seed: int = 1234,
         num_test_examples: int = 5,
         validation_split: float = 0.1,
-        **kwargs
+        max_phoneme_ids: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -111,14 +108,21 @@ class VitsModel(pl.LightningModule):
         self._train_dataset: Optional[Dataset] = None
         self._val_dataset: Optional[Dataset] = None
         self._test_dataset: Optional[Dataset] = None
-        self._load_datasets(validation_split, num_test_examples)
+        self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)
 
         # State kept between training optimizers
         self._y = None
         self._y_hat = None
 
-    def _load_datasets(self, validation_split: float, num_test_examples: int):
-        full_dataset = LarynxDataset(self.hparams.dataset)
+    def _load_datasets(
+        self,
+        validation_split: float,
+        num_test_examples: int,
+        max_phoneme_ids: Optional[int] = None,
+    ):
+        full_dataset = LarynxDataset(
+            self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
+        )
         valid_set_size = int(len(full_dataset) * validation_split)
         train_set_size = len(full_dataset) - valid_set_size - num_test_examples
 
@@ -211,9 +215,7 @@ class VitsModel(pl.LightningModule):
             self.hparams.mel_fmax,
         )
         y_mel = slice_segments(
-            mel,
-            ids_slice,
-            self.hparams.segment_size // self.hparams.hop_length,
+            mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
         )
         y_hat_mel = mel_spectrogram_torch(
             y_hat.squeeze(1),
@@ -226,9 +228,7 @@ class VitsModel(pl.LightningModule):
             self.hparams.mel_fmax,
         )
         y = slice_segments(
-            y,
-            ids_slice * self.hparams.hop_length,
-            self.hparams.segment_size,
+            y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
         ) # slice
 
         # Save for training_step_d
@@ -320,6 +320,11 @@ class VitsModel(pl.LightningModule):
         parser.add_argument("--batch-size", type=int, required=True)
         parser.add_argument("--validation-split", type=float, default=0.1)
         parser.add_argument("--num-test-examples", type=int, default=5)
+        parser.add_argument(
+            "--max-phoneme-ids",
+            type=int,
+            help="Exclude utterances with phoneme id lists longer than this",
+        )
         #
         parser.add_argument("--hidden-channels", type=int, default=192)
         parser.add_argument("--inter-channels", type=int, default=192)