diff --git a/src/python/larynx_train/export_onnx.py b/src/python/larynx_train/export_onnx.py index 9f7b057..d4e9573 100644 --- a/src/python/larynx_train/export_onnx.py +++ b/src/python/larynx_train/export_onnx.py @@ -8,16 +8,16 @@ import torch from .vits.lightning import VitsModel -_LOGGER = logging.getLogger("mimic3_train.export_onnx") +_LOGGER = logging.getLogger("larynx_train.export_onnx") OPSET_VERSION = 15 def main(): """Main entry point""" - torch.manual_seed(12345) + torch.manual_seed(1234) - parser = argparse.ArgumentParser(prog="mimic3_train.export_onnx") + parser = argparse.ArgumentParser() parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)") parser.add_argument("output", help="Path to output model (.onnx)") @@ -70,7 +70,10 @@ def main(): model_g.forward = infer_forward - sequences = torch.randint(low=0, high=num_symbols, size=(1, 50), dtype=torch.long) + dummy_input_length = 50 + sequences = torch.randint( + low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long + ) sequence_lengths = torch.LongTensor([sequences.size(1)]) sid: Optional[int] = None @@ -79,7 +82,6 @@ def main(): # noise, noise_w, length scales = torch.FloatTensor([0.667, 1.0, 0.8]) - dummy_input = (sequences, sequence_lengths, scales, sid) # Export