Skip to content

Commit 52d2e6b

Browse files
authored
Merge pull request #529 from idiap/simplify
Fix TTS/bin scripts, small refactors
2 parents 5548ef1 + fdc061e commit 52d2e6b

File tree

20 files changed

+301
-273
lines changed

20 files changed

+301
-273
lines changed

TTS/bin/compute_embeddings.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -94,19 +94,19 @@ def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
9494
help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
9595
default=None,
9696
)
97-
return parser.parse_args()
97+
return parser.parse_args(arg_list)
9898

9999

100100
def compute_embeddings(
101101
model_path,
102102
config_path,
103103
output_path,
104-
old_speakers_file=None,
104+
old_speakers_file: str | None = None,
105105
old_append=False,
106106
config_dataset_path=None,
107-
formatter_name=None,
108-
dataset_name=None,
109-
dataset_path=None,
107+
formatter_name: str | None = None,
108+
dataset_name: str | None = None,
109+
dataset_path: str | None = None,
110110
meta_file_train=None,
111111
meta_file_val=None,
112112
disable_cuda=False,
@@ -128,11 +128,7 @@ def compute_embeddings(
128128
c_dataset.meta_file_val = meta_file_val
129129
meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval)
130130

131-
if meta_data_eval is None:
132-
samples = meta_data_train
133-
else:
134-
samples = meta_data_train + meta_data_eval
135-
131+
samples = meta_data_train + meta_data_eval
136132
encoder_manager = SpeakerManager(
137133
encoder_model_path=model_path,
138134
encoder_config_path=config_path,
@@ -182,6 +178,7 @@ def compute_embeddings(
182178

183179
save_file(speaker_mapping, mapping_file_path)
184180
print("Speaker embeddings saved at:", mapping_file_path)
181+
sys.exit(0)
185182

186183

187184
def main(arg_list: list[str] | None = None):

TTS/bin/compute_statistics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#!/usr/bin/env python3
22

33
import argparse
4-
import glob
54
import logging
6-
import os
75
import sys
6+
from pathlib import Path
87

98
import numpy as np
109
from tqdm import tqdm
@@ -19,7 +18,7 @@
1918
def parse_args(arg_list: list[str] | None) -> tuple[argparse.Namespace, list[str]]:
2019
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
2120
parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
22-
parser.add_argument("out_path", type=str, help="save path (directory and filename).")
21+
parser.add_argument("out_path", type=str, help="save path (directory and filename).", default="scale_stats.npy")
2322
parser.add_argument(
2423
"--data_path",
2524
type=str,
@@ -46,7 +45,7 @@ def main(arg_list: list[str] | None = None):
4645

4746
# load the meta data of target dataset
4847
if args.data_path:
49-
dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True)
48+
dataset_items = list(Path(args.data_path).rglob("*.wav"))
5049
else:
5150
dataset_items = load_tts_samples(CONFIG.datasets)[0] # take only train data
5251
print(f" > There are {len(dataset_items)} files.")
@@ -95,6 +94,7 @@ def main(arg_list: list[str] | None = None):
9594
del CONFIG.audio.symmetric_norm
9695
del CONFIG.audio.clip_norm
9796
stats["audio_config"] = CONFIG.audio.to_dict()
97+
Path(output_file_path).parent.mkdir(exist_ok=True, parents=True)
9898
np.save(output_file_path, stats, allow_pickle=True)
9999
print(f" > stats saved to {output_file_path}")
100100
sys.exit(0)

TTS/bin/extract_tts_spectrograms.py

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,72 @@
2727

2828

2929
def parse_args(arg_list: list[str] | None) -> argparse.Namespace:
30-
parser = argparse.ArgumentParser()
31-
parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
32-
parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
33-
parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
34-
parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
35-
parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
36-
parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
37-
parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True)
30+
parser = argparse.ArgumentParser(
31+
description="""Extract mel spectrograms from audio using teacher forcing with a trained TTS model.
32+
33+
This script loads a trained TTS model and extracts mel spectrograms by running the model with teacher forcing.
34+
This is useful for analyzing model predictions, creating training data for downstream models, or debugging
35+
model behavior. Supports Tacotron, Tacotron2, and Glow-TTS models.
36+
37+
The script will create subdirectories in the output path:
38+
- mel/: Extracted mel spectrograms (.npy files)
39+
- wav/: Original audio files (if --save_audio is enabled)
40+
- wav_gl/: Griffin-Lim reconstructed audio from mels (if --debug is enabled)
41+
- quant/: Quantized audio files (if --quantize_bits > 0)""",
42+
formatter_class=argparse.RawDescriptionHelpFormatter,
43+
epilog="""Example usage:
44+
python extract_tts_spectrograms.py \\
45+
--config_path /path/to/config.json \\
46+
--checkpoint_path /path/to/checkpoint.pth \\
47+
--output_path /path/to/output""",
48+
)
49+
parser.add_argument(
50+
"--config_path",
51+
type=str,
52+
help="Path to the model configuration file (JSON) used during training. "
53+
"This config defines the model architecture, audio parameters, and dataset settings.",
54+
required=True,
55+
)
56+
parser.add_argument(
57+
"--checkpoint_path",
58+
type=str,
59+
help="Path to the trained model checkpoint file (.pth) to be loaded for inference.",
60+
required=True,
61+
)
62+
parser.add_argument(
63+
"--output_path",
64+
type=str,
65+
help="Directory path where extracted mel spectrograms and optional audio files will be saved. "
66+
"Subdirectories will be created automatically.",
67+
default="output_extract_tts_spectrograms",
68+
)
69+
parser.add_argument(
70+
"--debug",
71+
default=False,
72+
action="store_true",
73+
help="Enable debug mode: saves Griffin-Lim reconstructed audio files from the extracted mel spectrograms "
74+
"to wav_gl/ subdirectory for quality inspection.",
75+
)
76+
parser.add_argument(
77+
"--save_audio",
78+
default=False,
79+
action="store_true",
80+
help="Save the original audio files to the wav/ subdirectory alongside the extracted mel spectrograms.",
81+
)
82+
parser.add_argument(
83+
"--quantize_bits",
84+
type=int,
85+
default=0,
86+
help="Bit depth for audio quantization (e.g., 8, 16). If set to a non-zero value, saves quantized versions "
87+
"of audio files to the quant/ subdirectory. Set to 0 (default) to disable quantization.",
88+
)
89+
parser.add_argument(
90+
"--eval",
91+
action=argparse.BooleanOptionalAction,
92+
help="Include evaluation split in processing. When enabled (default), processes both training and evaluation "
93+
"samples. Use --no-eval to process only training samples.",
94+
default=True,
95+
)
3896
return parser.parse_args(arg_list)
3997

4098

@@ -75,19 +133,6 @@ def setup_loader(config: BaseTTSConfig, ap: AudioProcessor, r, speaker_manager:
75133
)
76134

77135

78-
def set_filename(wav_path: str, out_path: Path) -> tuple[Path, Path, Path, Path]:
79-
wav_name = Path(wav_path).stem
80-
(out_path / "quant").mkdir(exist_ok=True, parents=True)
81-
(out_path / "mel").mkdir(exist_ok=True, parents=True)
82-
(out_path / "wav_gl").mkdir(exist_ok=True, parents=True)
83-
(out_path / "wav").mkdir(exist_ok=True, parents=True)
84-
wavq_path = out_path / "quant" / wav_name
85-
mel_path = out_path / "mel" / wav_name
86-
wav_gl_path = out_path / "wav_gl" / f"{wav_name}.wav"
87-
out_wav_path = out_path / "wav" / f"{wav_name}.wav"
88-
return wavq_path, mel_path, wav_gl_path, out_wav_path
89-
90-
91136
def format_data(data):
92137
# setup input data
93138
text_input = data["token_id"]
@@ -213,34 +258,36 @@ def extract_spectrograms(
213258
d_vectors,
214259
)
215260

261+
(output_path / "mel").mkdir(exist_ok=True, parents=True)
216262
for idx in range(text_input.shape[0]):
217-
wav_file_path = item_idx[idx]
263+
wav_file_path = Path(item_idx[idx])
218264
wav = ap.load_wav(wav_file_path)
219-
wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
220265

221266
# quantize and save wav
222267
if quantize_bits > 0:
223-
wavq = quantize(wav, quantize_bits)
224-
np.save(wavq_path, wavq)
268+
wavq = quantize(x=wav, quantize_bits=quantize_bits)
269+
(output_path / "quant").mkdir(exist_ok=True)
270+
np.save(output_path / "quant" / wav_file_path.stem, wavq)
225271

226272
# save TTS mel
227273
mel = model_output[idx]
228274
mel_length = mel_lengths[idx]
229275
mel = mel[:mel_length, :].T
230-
np.save(mel_path, mel)
276+
np.save(output_path / "mel" / wav_file_path.stem, mel)
231277

232-
export_metadata.append([wav_file_path, mel_path])
278+
export_metadata.append(output_path / "mel" / wav_file_path.stem)
233279
if save_audio:
234-
ap.save_wav(wav, wav_path)
280+
(output_path / "wav").mkdir(exist_ok=True)
281+
ap.save_wav(wav, output_path / "wav" / f"{wav_file_path.stem}.wav")
235282

236283
if debug:
237-
print("Audio for debug saved at:", wav_gl_path)
238-
wav = ap.inv_melspectrogram(mel)
239-
ap.save_wav(wav, wav_gl_path)
284+
wav_gl = ap.inv_melspectrogram(mel)
285+
(output_path / "wav_gl").mkdir(exist_ok=True)
286+
ap.save_wav(wav_gl, output_path / "wav_gl" / f"{wav_file_path.stem}.wav")
240287

241288
with (output_path / metadata_name).open("w") as f:
242-
for data in export_metadata:
243-
f.write(f"{data[0] / data[1]}.npy\n")
289+
for path in export_metadata:
290+
f.write(f"{path}.npy\n")
244291

245292

246293
def main(arg_list: list[str] | None = None) -> None:
@@ -264,12 +311,7 @@ def main(arg_list: list[str] | None = None) -> None:
264311
meta_data = meta_data_train + meta_data_eval
265312

266313
# init speaker manager
267-
if config.use_speaker_embedding:
268-
speaker_manager = SpeakerManager(data_items=meta_data)
269-
elif config.use_d_vector_file:
270-
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
271-
else:
272-
speaker_manager = None
314+
speaker_manager = SpeakerManager.init_from_config(config)
273315

274316
# setup model
275317
model = setup_model(config)

TTS/config/__init__.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -108,32 +108,7 @@ def load_config(config_path: str | os.PathLike[Any]) -> BaseTrainingConfig:
108108
return config
109109

110110

111-
def check_config_and_model_args(config, arg_name, value):
112-
"""Check the give argument in `config.model_args` if exist or in `config` for
113-
the given value.
114-
115-
Return False if the argument does not exist in `config.model_args` or `config`.
116-
This is to patch up the compatibility between models with and without `model_args`.
117-
118-
TODO: Remove this in the future with a unified approach.
119-
"""
120-
if getattr(config, "model_args", None) is not None:
121-
if arg_name in config.model_args:
122-
return config.model_args[arg_name] == value
123-
if hasattr(config, arg_name):
124-
return config[arg_name] == value
125-
return False
126-
127-
128-
def get_from_config_or_model_args(config, arg_name):
129-
"""Get the given argument from `config.model_args` if exist or in `config`."""
130-
if getattr(config, "model_args", None) is not None:
131-
if arg_name in config.model_args:
132-
return config.model_args[arg_name]
133-
return config[arg_name]
134-
135-
136-
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
111+
def get_from_config_or_model_args(config: Coqpit, arg_name: str, def_val: Any = None) -> Any:
137112
"""Get the given argument from `config.model_args` if exist or in `config`."""
138113
if getattr(config, "model_args", None) is not None:
139114
if arg_name in config.model_args:

TTS/config/shared_configs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,11 @@ class BaseDatasetConfig(Coqpit):
222222
train the duration predictor.
223223
"""
224224

225-
formatter: str = ""
226-
dataset_name: str = ""
227-
path: str = ""
225+
formatter: str | None = ""
226+
dataset_name: str | None = ""
227+
path: str | None = ""
228228
meta_file_train: str = ""
229-
ignored_speakers: list[str] = None
229+
ignored_speakers: list[str] | None = None
230230
language: str = ""
231231
phonemizer: str = ""
232232
meta_file_val: str = ""

TTS/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from trainer import TrainerModel
88
from trainer.io import load_fsspec
99

10-
# pylint: skip-file
11-
1210

1311
class BaseTrainerModel(TrainerModel):
1412
"""BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS.
@@ -29,7 +27,7 @@ def init_from_config(config: Coqpit) -> "BaseTrainerModel":
2927
def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
3028
"""Forward pass for inference.
3129
32-
It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
30+
Must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
3331
is considered to be the main output and you can add any other auxiliary outputs as you want.
3432
3533
We don't use `*kwargs` since it is problematic with the TorchScript API.
@@ -40,6 +38,7 @@ def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict
4038
4139
Returns:
4240
Dict: [description]
41+
4342
"""
4443
outputs_dict = {"model_outputs": None}
4544
...
@@ -53,6 +52,7 @@ def load_checkpoint(
5352
eval: bool = False,
5453
strict: bool = True,
5554
cache: bool = False,
55+
**kwargs: Any,
5656
) -> None:
5757
"""Load a model checkpoint file and get ready for training or inference.
5858
@@ -63,6 +63,7 @@ def load_checkpoint(
6363
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
6464
cache (bool, optional): If True, cache the file locally for subsequent calls.
6565
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
66+
6667
"""
6768
state = load_fsspec(checkpoint_path, map_location="cpu", cache=cache)
6869
self.load_state_dict(state["model"], strict=strict)
@@ -71,4 +72,5 @@ def load_checkpoint(
7172

7273
@property
7374
def device(self) -> torch.device:
75+
"""Return device of the model based on its parameters."""
7476
return next(self.parameters()).device

TTS/tts/configs/shared_configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
BaseAudioConfig,
88
BaseDatasetConfig,
99
BaseTrainingConfig,
10-
get_from_config_or_model_args_with_default,
10+
get_from_config_or_model_args,
1111
)
1212

1313

@@ -357,6 +357,6 @@ class BaseTTSConfig(BaseTrainingConfig):
357357
@property
358358
def supports_cloning(self) -> bool:
359359
return self._supports_cloning or (
360-
Path(get_from_config_or_model_args_with_default(self, "speaker_encoder_model_path", "")).is_file()
361-
and Path(get_from_config_or_model_args_with_default(self, "speaker_encoder_config_path", "")).is_file()
360+
Path(get_from_config_or_model_args(self, "speaker_encoder_model_path", "")).is_file()
361+
and Path(get_from_config_or_model_args(self, "speaker_encoder_config_path", "")).is_file()
362362
)

TTS/tts/configs/vits_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,15 @@ class VitsConfig(BaseTTSConfig):
159159
# use speaker embedding layer
160160
num_speakers: int = 0
161161
use_speaker_embedding: bool = False
162-
speakers_file: str = None
162+
speakers_file: str | None = None
163163
speaker_embedding_channels: int = 256
164-
language_ids_file: str = None
164+
language_ids_file: str | None = None
165165
use_language_embedding: bool = False
166166

167167
# use d-vectors
168168
use_d_vector_file: bool = False
169-
d_vector_file: list[str] = None
170-
d_vector_dim: int = None
169+
d_vector_file: str | list[str] | None = None
170+
d_vector_dim: int | None = None
171171

172172
def __post_init__(self):
173173
for key, val in self.model_args.items():

0 commit comments

Comments
 (0)