#!/usr/bin/env python3
"""Export the generator of a trained VITS checkpoint to an ONNX model."""
import argparse
import logging
from pathlib import Path
from typing import Optional

import torch

from .vits.lightning import VitsModel

_LOGGER = logging.getLogger("mimic3_train.export_onnx")

OPSET_VERSION = 15


def main():
    """Main entry point: load a checkpoint and export its generator to ONNX.

    The exported graph has these inputs:

    - ``input``: phoneme ids.
    - ``input_lengths``: length of each sequence.
    - ``scales``: ``[noise_scale, length_scale, noise_scale_w]``.
    - ``sid``: speaker id, present only for multi-speaker models.

    It has a single output, ``output``.
    """
    torch.manual_seed(12345)

    parser = argparse.ArgumentParser(prog="mimic3_train.export_onnx")
    parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)")
    parser.add_argument("output", help="Path to output model (.onnx)")
    parser.add_argument(
        "--debug", action="store_true", help="Print DEBUG messages to the console"
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    _LOGGER.debug(args)

    # -------------------------------------------------------------------------

    args.checkpoint = Path(args.checkpoint)
    args.output = Path(args.output)
    args.output.parent.mkdir(parents=True, exist_ok=True)

    model = VitsModel.load_from_checkpoint(args.checkpoint)
    model_g = model.model_g

    num_symbols = model_g.n_vocab
    num_speakers = model_g.n_speakers

    # Inference only
    model_g.eval()

    with torch.no_grad():
        model_g.dec.remove_weight_norm()

    def infer_forward(text, text_lengths, scales, sid=None):
        """Adapter so ONNX export sees a single ``scales`` tensor input.

        ``scales`` is unpacked in the order noise, length, noise_w.
        """
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0].unsqueeze(1)

        return audio

    # Replace forward so the export runs infer_forward (i.e. model_g.infer)
    # instead of the module's original forward.
    model_g.forward = infer_forward

    # Dummy inputs used to run the model once during export.
    sequences = torch.randint(low=0, high=num_symbols, size=(1, 50), dtype=torch.long)
    sequence_lengths = torch.LongTensor([sequences.size(1)])

    sid: Optional[torch.Tensor] = None
    if num_speakers > 1:
        sid = torch.LongTensor([0])

    # noise, length, noise_w -- must match the unpacking order in infer_forward
    scales = torch.FloatTensor([0.667, 1.0, 0.8])

    dummy_input = (sequences, sequence_lengths, scales)
    input_names = ["input", "input_lengths", "scales"]

    # Only declare a "sid" input when one is actually passed. Otherwise the
    # extra name would be applied to some other graph value.
    if sid is not None:
        dummy_input += (sid,)
        input_names.append("sid")

    # Export
    torch.onnx.export(
        model=model_g,
        args=dummy_input,
        f=str(args.output),
        # Dumping the full graph is very noisy; only do it with --debug.
        verbose=args.debug,
        opset_version=OPSET_VERSION,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phonemes"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "time"},
        },
    )

    _LOGGER.info("Exported model to %s", args.output)


# -----------------------------------------------------------------------------

if __name__ == "__main__":
    main()