diff --git a/src/python/larynx_train/infer.py b/src/python/larynx_train/infer.py
index 924ad3b..993a305 100644
--- a/src/python/larynx_train/infer.py
+++ b/src/python/larynx_train/infer.py
@@ -43,16 +43,17 @@ def main():
             continue
 
         utt = json.loads(line)
-        # utt_id = utt["id"]
         utt_id = str(i)
         phoneme_ids = utt["phoneme_ids"]
+        speaker_id = utt.get("speaker_id")
 
         text = torch.LongTensor(phoneme_ids).unsqueeze(0)
         text_lengths = torch.LongTensor([len(phoneme_ids)])
         scales = [0.667, 1.0, 0.8]
+        sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None
 
         start_time = time.perf_counter()
-        audio = model(text, text_lengths, scales).detach().numpy()
+        audio = model(text, text_lengths, scales, sid=sid).detach().numpy()
         audio = audio_float_to_int16(audio)
         end_time = time.perf_counter()
 
diff --git a/src/python/larynx_train/preprocess.py b/src/python/larynx_train/preprocess.py
index be625f9..3568972 100644
--- a/src/python/larynx_train/preprocess.py
+++ b/src/python/larynx_train/preprocess.py
@@ -147,6 +147,9 @@ def main():
     for _ in range(num_utterances):
         utt = queue_out.get()
         if utt is not None:
+            if utt.speaker is not None:
+                utt.speaker_id = speaker_ids[utt.speaker]
+
             # JSONL
             json.dump(
                 dataclasses.asdict(utt),
@@ -207,6 +210,7 @@ class Utterance:
     text: str
     audio_path: Path
     speaker: Optional[str] = None
+    speaker_id: Optional[int] = None
     phonemes: Optional[List[str]] = None
     phoneme_ids: Optional[List[int]] = None
     audio_norm_path: Optional[Path] = None