mirror of
https://github.com/pstrueb/piper.git
synced 2026-04-18 22:34:49 +00:00
Add --max-phoneme-ids
This commit is contained in:
@@ -67,15 +67,19 @@ class LarynxDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_paths: List[Union[str, Path]], # settings: LarynxDatasetSettings
|
||||
dataset_paths: List[Union[str, Path]],
|
||||
max_phoneme_ids: Optional[int] = None,
|
||||
):
|
||||
# self.settings = settings
|
||||
self.utterances: List[Utterance] = []
|
||||
|
||||
for dataset_path in dataset_paths:
|
||||
dataset_path = Path(dataset_path)
|
||||
_LOGGER.debug("Loading dataset: %s", dataset_path)
|
||||
self.utterances.extend(LarynxDataset.load_dataset(dataset_path))
|
||||
self.utterances.extend(
|
||||
LarynxDataset.load_dataset(
|
||||
dataset_path, max_phoneme_ids=max_phoneme_ids
|
||||
)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.utterances)
|
||||
@@ -93,7 +97,11 @@ class LarynxDataset(Dataset):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_dataset(dataset_path: Path) -> Iterable[Utterance]:
|
||||
def load_dataset(
|
||||
dataset_path: Path, max_phoneme_ids: Optional[int] = None,
|
||||
) -> Iterable[Utterance]:
|
||||
num_skipped = 0
|
||||
|
||||
with open(dataset_path, "r", encoding="utf-8") as dataset_file:
|
||||
for line_idx, line in enumerate(dataset_file):
|
||||
line = line.strip()
|
||||
@@ -101,15 +109,21 @@ class LarynxDataset(Dataset):
|
||||
continue
|
||||
|
||||
try:
|
||||
yield LarynxDataset.load_utterance(line)
|
||||
utt = LarynxDataset.load_utterance(line)
|
||||
if (max_phoneme_ids is None) or (
|
||||
len(utt.phoneme_ids) <= max_phoneme_ids
|
||||
):
|
||||
yield utt
|
||||
else:
|
||||
num_skipped += 1
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Error on line %s of %s: %s",
|
||||
line_idx + 1,
|
||||
dataset_path,
|
||||
line,
|
||||
"Error on line %s of %s: %s", line_idx + 1, dataset_path, line,
|
||||
)
|
||||
|
||||
if num_skipped > 0:
|
||||
_LOGGER.warning("Skipped %s utterance(s)", num_skipped)
|
||||
|
||||
@staticmethod
|
||||
def load_utterance(line: str) -> Utterance:
|
||||
utt_dict = json.loads(line)
|
||||
|
||||
@@ -25,11 +25,7 @@ class VitsModel(pl.LightningModule):
|
||||
# audio
|
||||
resblock="2",
|
||||
resblock_kernel_sizes=(3, 5, 7),
|
||||
resblock_dilation_sizes=(
|
||||
(1, 2),
|
||||
(2, 6),
|
||||
(3, 12),
|
||||
),
|
||||
resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
|
||||
upsample_rates=(8, 8, 4),
|
||||
upsample_initial_channel=256,
|
||||
upsample_kernel_sizes=(16, 16, 8),
|
||||
@@ -72,7 +68,8 @@ class VitsModel(pl.LightningModule):
|
||||
seed: int = 1234,
|
||||
num_test_examples: int = 5,
|
||||
validation_split: float = 0.1,
|
||||
**kwargs
|
||||
max_phoneme_ids: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
@@ -111,14 +108,21 @@ class VitsModel(pl.LightningModule):
|
||||
self._train_dataset: Optional[Dataset] = None
|
||||
self._val_dataset: Optional[Dataset] = None
|
||||
self._test_dataset: Optional[Dataset] = None
|
||||
self._load_datasets(validation_split, num_test_examples)
|
||||
self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)
|
||||
|
||||
# State kept between training optimizers
|
||||
self._y = None
|
||||
self._y_hat = None
|
||||
|
||||
def _load_datasets(self, validation_split: float, num_test_examples: int):
|
||||
full_dataset = LarynxDataset(self.hparams.dataset)
|
||||
def _load_datasets(
|
||||
self,
|
||||
validation_split: float,
|
||||
num_test_examples: int,
|
||||
max_phoneme_ids: Optional[int] = None,
|
||||
):
|
||||
full_dataset = LarynxDataset(
|
||||
self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
|
||||
)
|
||||
valid_set_size = int(len(full_dataset) * validation_split)
|
||||
train_set_size = len(full_dataset) - valid_set_size - num_test_examples
|
||||
|
||||
@@ -211,9 +215,7 @@ class VitsModel(pl.LightningModule):
|
||||
self.hparams.mel_fmax,
|
||||
)
|
||||
y_mel = slice_segments(
|
||||
mel,
|
||||
ids_slice,
|
||||
self.hparams.segment_size // self.hparams.hop_length,
|
||||
mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
|
||||
)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1),
|
||||
@@ -226,9 +228,7 @@ class VitsModel(pl.LightningModule):
|
||||
self.hparams.mel_fmax,
|
||||
)
|
||||
y = slice_segments(
|
||||
y,
|
||||
ids_slice * self.hparams.hop_length,
|
||||
self.hparams.segment_size,
|
||||
y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
|
||||
) # slice
|
||||
|
||||
# Save for training_step_d
|
||||
@@ -320,6 +320,11 @@ class VitsModel(pl.LightningModule):
|
||||
parser.add_argument("--batch-size", type=int, required=True)
|
||||
parser.add_argument("--validation-split", type=float, default=0.1)
|
||||
parser.add_argument("--num-test-examples", type=int, default=5)
|
||||
parser.add_argument(
|
||||
"--max-phoneme-ids",
|
||||
type=int,
|
||||
help="Exclude utterances with phoneme id lists longer than this",
|
||||
)
|
||||
#
|
||||
parser.add_argument("--hidden-channels", type=int, default=192)
|
||||
parser.add_argument("--inter-channels", type=int, default=192)
|
||||
|
||||
Reference in New Issue
Block a user