Pass speaker id during verification

pull/2/head
Michael Hansen 1 year ago
parent c38020cb14
commit 06a154a4ed

@@ -98,7 +98,8 @@ class LarynxDataset(Dataset):
@staticmethod
def load_dataset(
dataset_path: Path, max_phoneme_ids: Optional[int] = None,
dataset_path: Path,
max_phoneme_ids: Optional[int] = None,
) -> Iterable[Utterance]:
num_skipped = 0
@@ -118,7 +119,10 @@ class LarynxDataset(Dataset):
num_skipped += 1
except Exception:
_LOGGER.exception(
"Error on line %s of %s: %s", line_idx + 1, dataset_path, line,
"Error on line %s of %s: %s",
line_idx + 1,
dataset_path,
line,
)
if num_skipped > 0:

@@ -25,7 +25,11 @@ class VitsModel(pl.LightningModule):
# audio
resblock="2",
resblock_kernel_sizes=(3, 5, 7),
resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
resblock_dilation_sizes=(
(1, 2),
(2, 6),
(3, 12),
),
upsample_rates=(8, 8, 4),
upsample_initial_channel=256,
upsample_kernel_sizes=(16, 16, 8),
@@ -215,7 +219,9 @@ class VitsModel(pl.LightningModule):
self.hparams.mel_fmax,
)
y_mel = slice_segments(
mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
mel,
ids_slice,
self.hparams.segment_size // self.hparams.hop_length,
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
@@ -228,7 +234,9 @@ class VitsModel(pl.LightningModule):
self.hparams.mel_fmax,
)
y = slice_segments(
y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
y,
ids_slice * self.hparams.hop_length,
self.hparams.segment_size,
) # slice
# Save for training_step_d
@@ -276,7 +284,12 @@ class VitsModel(pl.LightningModule):
text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
scales = [0.667, 1.0, 0.8]
test_audio = self(text, text_lengths, scales).detach()
sid = (
test_utt.speaker_id.to(self.device)
if test_utt.speaker_id is not None
else None
)
test_audio = self(text, text_lengths, scales, sid=sid).detach()
# Scale to make louder in [-1, 1]
test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

@@ -686,6 +686,7 @@ class SynthesizerTrn(nn.Module):
):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 1:
assert sid is not None, "Missing speaker id"
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None

Loading…
Cancel
Save