Exporting to TorchScript

This commit is contained in:
Michael Hansen
2023-01-08 21:38:54 -06:00
parent f1cc4e58bd
commit e5062c9496
5 changed files with 115 additions and 6 deletions

View File

@@ -51,7 +51,26 @@ def main():
model_g.forward = model_g.infer
jitted_model = torch.jit.script(model_g)
dummy_input_length = 50
sequences = torch.randint(
low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
)
sequence_lengths = torch.LongTensor([sequences.size(1)])
sid: Optional[int] = None
if num_speakers > 1:
sid = torch.LongTensor([0])
dummy_input = (
sequences,
sequence_lengths,
sid,
torch.FloatTensor([0.667]),
torch.FloatTensor([1.0]),
torch.FloatTensor([0.8]),
)
jitted_model = torch.jit.trace(model_g, dummy_input)
torch.jit.save(jitted_model, str(args.output))
_LOGGER.info("Saved TorchScript model to %s", args.output)