Skip to content

Commit 69704ee

Browse files
authored
Merge pull request #265 from idiap/dev
v0.25.3
2 parents 2b694c1 + 44c3491 commit 69704ee

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

TTS/tts/models/vits.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import os
44
from dataclasses import dataclass, field, replace
55
from itertools import chain
6-
from typing import Dict, List, Tuple, Union
6+
from pathlib import Path
7+
from typing import Any, Dict, List, Tuple, Union
78

89
import numpy as np
910
import torch
@@ -1581,13 +1582,16 @@ def load_fairseq_checkpoint(
15811582

15821583
self.disc = None
15831584
# set paths
1584-
config_file = os.path.join(checkpoint_dir, "config.json")
1585-
checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
1586-
vocab_file = os.path.join(checkpoint_dir, "vocab.txt")
1585+
checkpoint_dir = Path(checkpoint_dir)
1586+
config_file = checkpoint_dir / "config.json"
1587+
checkpoint_file = checkpoint_dir / "model.pth"
1588+
if not checkpoint_file.is_file():
1589+
checkpoint_file = checkpoint_dir / "G_100000.pth"
1590+
vocab_file = checkpoint_dir / "vocab.txt"
15871591
# set config params
1588-
with open(config_file, "r", encoding="utf-8") as file:
1592+
with open(config_file, "r", encoding="utf-8") as f:
15891593
# Load the JSON data as a dictionary
1590-
config_org = json.load(file)
1594+
config_org = json.load(f)
15911595
self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
15921596
# self.config.add_blank = config['add_blank']
15931597
# set tokenizer
@@ -1821,7 +1825,7 @@ def to_config(self) -> "CharactersConfig":
18211825

18221826

18231827
class FairseqVocab(BaseVocabulary):
1824-
def __init__(self, vocab: str):
1828+
def __init__(self, vocab: Union[str, os.PathLike[Any]]):
18251829
super(FairseqVocab).__init__()
18261830
self.vocab = vocab
18271831

@@ -1831,7 +1835,7 @@ def vocab(self):
18311835
return self._vocab
18321836

18331837
@vocab.setter
1834-
def vocab(self, vocab_file):
1838+
def vocab(self, vocab_file: Union[str, os.PathLike[Any]]):
18351839
with open(vocab_file, encoding="utf-8") as f:
18361840
self._vocab = [x.replace("\n", "") for x in f.readlines()]
18371841
self.blank = self._vocab[0]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ build-backend = "hatchling.build"
2525

2626
[project]
2727
name = "coqui-tts"
28-
version = "0.25.2"
28+
version = "0.25.3"
2929
description = "Deep learning for Text to Speech."
3030
readme = "README.md"
3131
requires-python = ">=3.9, <3.13"

tests/zoo_tests/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def manager(tmp_path):
3737
num_partitions = int(os.getenv("NUM_PARTITIONS", "1"))
3838
partition = int(os.getenv("TEST_PARTITION", "0"))
3939
model_names = [name for name in TTS.list_models() if name not in MODELS_WITH_SEP_TESTS]
40+
model_names.extend(["tts_models/deu/fairseq/vits", "tts_models/sqi/fairseq/vits"])
4041
model_names = [name for i, name in enumerate(model_names) if i % num_partitions == partition]
4142

4243

0 commit comments

Comments
 (0)