Skip to content

Commit 3250edd

Browse files
authored
Merge pull request #525 from idiap/torch29
fix: support pytorch 2.9
2 parents 268b532 + 1dc1820 commit 3250edd

File tree

8 files changed

+44
-15
lines changed

8 files changed

+44
-15
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
if [ "${{ matrix.python-version }}" == "3.10" ]; then
7777
resolution=lowest-direct
7878
fi
79-
uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
79+
uv run --resolution=$resolution --extra codec --extra server --extra languages make ${{ matrix.subset }}
8080
- name: Upload coverage data
8181
uses: actions/upload-artifact@v4
8282
with:
@@ -119,7 +119,7 @@ jobs:
119119
if [ "${{ matrix.python-version }}" == "3.10" ]; then
120120
resolution=lowest-direct
121121
fi
122-
uv run --resolution=$resolution --extra languages coverage run -m pytest -x -v --durations=0 $shard_tests
122+
uv run --resolution=$resolution --extra codec --extra languages coverage run -m pytest -x -v --durations=0 $shard_tests
123123
- name: Upload coverage data
124124
uses: actions/upload-artifact@v4
125125
with:
@@ -154,7 +154,7 @@ jobs:
154154
uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }}
155155
fi
156156
- name: Zoo tests
157-
run: uv run --extra server --extra languages make test_zoo
157+
run: uv run --extra codec --extra server --extra languages make test_zoo
158158
env:
159159
NUM_PARTITIONS: 3
160160
TEST_PARTITION: ${{ matrix.partition }}

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ lint: ## run linters.
4444
uv run --only-dev ruff format ${target_dirs} --check
4545

4646
system-deps: ## install linux system deps
47-
sudo apt-get install -y libsndfile1-dev
47+
sudo apt-get install -y libsndfile1-dev ffmpeg
4848

4949
install: ## install 🐸 TTS
5050
uv sync --all-extras

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ You can also help us implement more models.
118118
## Installation
119119

120120
🐸TTS is tested on Ubuntu 24.04 with **python >= 3.10, < 3.14**, but should also
121-
work on Mac and Windows.
121+
work on Mac and Windows. Depending on your platform, you might first want to
122+
separately install Pytorch, `torchaudio`, and `torchcodec` with their
123+
[official instructions](https://pytorch.org/get-started/locally/).
122124

123125
If you are only interested in [synthesizing speech](https://coqui-tts.readthedocs.io/en/latest/inference.html) with the pretrained 🐸TTS models, installing from PyPI is the easiest option.
124126

@@ -141,6 +143,7 @@ The following extras allow the installation of optional dependencies:
141143
| Name | Description |
142144
|------|-------------|
143145
| `all` | All optional dependencies |
146+
| `codec` | Installs torchcodec needed with Pytorch>=2.9 |
144147
| `notebooks` | Dependencies only used in notebooks |
145148
| `server` | Dependencies to run the TTS server |
146149
| `bn` | Bangla G2P |

TTS/tts/configs/vits_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class VitsConfig(BaseTTSConfig):
4242
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
4343
4444
scheduler_after_epoch (bool):
45-
If true, step the schedulers after each epoch else after each step. Defaults to `False`.
45+
If true, step the schedulers after each epoch else after each step. Defaults to `True`.
4646
4747
optimizer (str):
4848
Name of the optimizer to use with both the generator and the discriminator networks. One of the

TTS/tts/datasets/dataset.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
import logging
44
import os
55
import random
6+
from math import floor
67
from typing import Any
78

89
import numpy as np
910
import numpy.typing as npt
1011
import torch
11-
import torchaudio
1212
import tqdm
1313
from torch.utils.data import Dataset
1414

1515
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
1616
from TTS.utils.audio import AudioProcessor
1717
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
18+
from TTS.utils.generic_utils import is_pytorch_at_least_2_9
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -47,6 +48,20 @@ def string2filename(string: str) -> str:
4748
return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
4849

4950

51+
def _get_audio_size_torchcodec(audiopath: str | os.PathLike[Any]) -> int:
52+
try:
53+
from torchcodec.decoders import AudioDecoder
54+
except ImportError as e:
55+
msg = "torchcodec not installed (available in the `codec` extra)"
56+
raise ImportError(msg) from e
57+
except RuntimeError as e:
58+
msg = "Error while importing torchcodec, see the stacktrace for details."
59+
raise ImportError(msg) from e
60+
61+
metadata = AudioDecoder(audiopath).metadata
62+
return floor(metadata.duration_seconds_from_header * metadata.sample_rate)
63+
64+
5065
def get_audio_size(audiopath: str | os.PathLike[Any]) -> int:
5166
"""Return the number of samples in the audio file."""
5267
if not isinstance(audiopath, str):
@@ -57,7 +72,12 @@ def get_audio_size(audiopath: str | os.PathLike[Any]) -> int:
5772
raise RuntimeError(msg)
5873

5974
try:
60-
return torchaudio.info(audiopath).num_frames
75+
if is_pytorch_at_least_2_9():
76+
return _get_audio_size_torchcodec(audiopath)
77+
else:
78+
import torchaudio
79+
80+
return torchaudio.info(audiopath).num_frames
6181
except RuntimeError as e:
6282
msg = f"Failed to decode {audiopath}"
6383
raise RuntimeError(msg) from e

TTS/utils/generic_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def is_pytorch_at_least_2_4() -> bool:
161161
return Version(torch.__version__) >= Version("2.4")
162162

163163

164+
def is_pytorch_at_least_2_9() -> bool:
165+
"""Check if the installed Pytorch version is 2.4 or higher."""
166+
return Version(torch.__version__) >= Version("2.9")
167+
168+
164169
def optional_to_str(x: Any | None) -> str:
165170
"""Convert input to string, using empty string if input is None."""
166171
return "" if x is None else str(x)

TTS/vc/configs/freevc_config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,6 @@ class FreeVCConfig(BaseVCConfig):
172172
lr_scheduler_disc_params (dict):
173173
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
174174
175-
scheduler_after_epoch (bool):
176-
If true, step the schedulers after each epoch else after each step. Defaults to `False`.
177-
178175
optimizer (str):
179176
Name of the optimizer to use with both the generator and the discriminator networks. One of the
180177
`torch.optim.*`. Defaults to `AdamW`.

pyproject.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ build-backend = "hatchling.build"
2525

2626
[project]
2727
name = "coqui-tts"
28-
version = "0.27.2"
28+
version = "0.27.3"
2929
description = "Deep learning for Text to Speech."
3030
readme = "README.md"
3131
requires-python = ">=3.10, <3.14"
@@ -59,8 +59,8 @@ dependencies = [
5959
# Core
6060
"numpy>=1.26.0",
6161
"scipy>=1.13.0",
62-
"torch>=2.2,<2.9",
63-
"torchaudio>=2.2.0,<2.9",
62+
"torch>=2.2",
63+
"torchaudio>=2.2.0",
6464
"soundfile>=0.12.0",
6565
"librosa>=0.11.0",
6666
"numba>=0.58.0",
@@ -92,6 +92,10 @@ dependencies = [
9292
]
9393

9494
[project.optional-dependencies]
95+
# torchcodec needed from torch>=2.9
96+
codec = [
97+
"torchcodec>=0.8.0",
98+
]
9599
# Only used in notebooks
96100
notebooks = [
97101
"bokeh>=3.0.3",
@@ -128,7 +132,7 @@ languages = [
128132
]
129133
# Installs all extras (except dev and docs)
130134
all = [
131-
"coqui-tts[notebooks,server,bn,ja,ko,zh]",
135+
"coqui-tts[codec,notebooks,server,bn,ja,ko,zh]",
132136
]
133137

134138
[dependency-groups]

0 commit comments

Comments
 (0)