Files
meme/tts/factory.py
konjacpotato 6772699cfe
Some checks failed
Gitea Actions Demo / deploy (push) Failing after 2s
commit code
2025-12-29 19:34:39 +08:00

116 lines
3.3 KiB
Python

"""
TTS 引擎工厂类
"""
from enum import Enum
from typing import Optional
from .base import TTSEngine
from .edge_tts_engine import EdgeTTSEngine
from .cosyvoice_engine import CosyVoiceEngine
from utils.logger import logger
class TTSEngineType(Enum):
    """
    Identifiers of the TTS engines known to this package.

    Each member's value is the string form accepted by the factory
    (``TTSEngineType("edge-tts")`` -> ``TTSEngineType.EDGE_TTS``).
    To add an engine, declare a new member here and register its
    implementation class with ``TTSEngineFactory``.
    """

    EDGE_TTS = "edge-tts"
    COSYVOICE = "cosyvoice"
class TTSEngineFactory:
    """
    Factory for TTS engines.

    Creates engine instances on demand and caches at most one instance per
    engine type (see ``create``). Additional implementations can be
    registered at runtime with ``register_engine``.
    """

    # Maps each engine type to the TTSEngine subclass that implements it.
    # Register further implementations here or via ``register_engine``.
    _engines: dict[TTSEngineType, type[TTSEngine]] = {
        TTSEngineType.EDGE_TTS: EdgeTTSEngine,
        TTSEngineType.COSYVOICE: CosyVoiceEngine,
    }
    # Lazily created instances, at most one per engine type.
    _instances: dict[TTSEngineType, TTSEngine] = {}

    @classmethod
    def _to_engine_type(cls, engine_type: str | TTSEngineType) -> TTSEngineType:
        """
        Coerce *engine_type* to a TTSEngineType member.

        Args:
            engine_type: a TTSEngineType member or one of its string values.

        Returns:
            The matching TTSEngineType member.

        Raises:
            ValueError: if *engine_type* is neither a member nor a member's value.
        """
        if isinstance(engine_type, TTSEngineType):
            return engine_type
        try:
            return TTSEngineType(engine_type)
        except ValueError:
            # ``from None``: the enum's own lookup error adds nothing; this
            # message already lists the valid values.
            raise ValueError(
                f"Unsupported TTS engine type: {engine_type}. "
                f"Supported types: {[e.value for e in TTSEngineType]}"
            ) from None

    @classmethod
    def create(cls, engine_type: str | TTSEngineType) -> TTSEngine:
        """
        Return the shared TTS engine instance for *engine_type*.

        The instance is constructed (with no arguments) on first request and
        cached; later calls for the same type return the same object until
        ``clear_instances`` is called or a different class is registered for
        that type.

        Args:
            engine_type: a TTSEngineType member or its string value
                (e.g. ``"edge-tts"``).

        Returns:
            The cached TTSEngine instance for that type.

        Raises:
            ValueError: if the engine type is unknown or has no registered
                implementation.
        """
        engine_type = cls._to_engine_type(engine_type)

        instance = cls._instances.get(engine_type)
        if instance is None:
            engine_class = cls._engines.get(engine_type)
            if engine_class is None:
                raise ValueError(
                    f"TTS engine '{engine_type.value}' is not registered. "
                    f"Available engines: {[e.value for e in cls._engines]}"
                )
            instance = engine_class()
            cls._instances[engine_type] = instance
            logger.info(f"Created TTS engine instance: {engine_type.value}")
        return instance

    @classmethod
    def register_engine(
        cls, engine_type: str | TTSEngineType, engine_class: type[TTSEngine]
    ) -> None:
        """
        Register (or replace) the implementation class for an engine type.

        If a different class was previously registered for *engine_type*, any
        instance of it already cached by ``create`` is discarded, so the next
        ``create`` call builds an instance of the new class.

        Args:
            engine_type: a TTSEngineType member or its string value.
            engine_class: the class to instantiate; must subclass TTSEngine.

        Raises:
            TypeError: if *engine_class* is not a subclass of TTSEngine.
            ValueError: if *engine_type* is not a known engine type.
        """
        # Check for a class first: issubclass() raises its own, less helpful
        # TypeError when handed a non-class object.
        is_engine_class = isinstance(engine_class, type) and issubclass(
            engine_class, TTSEngine
        )
        if not is_engine_class:
            raise TypeError(f"{engine_class} must be a subclass of TTSEngine")

        engine_type = cls._to_engine_type(engine_type)
        previous_class = cls._engines.get(engine_type)
        cls._engines[engine_type] = engine_class
        if previous_class is not engine_class:
            # A cached instance of the old class would otherwise keep being
            # returned by create().
            cls._instances.pop(engine_type, None)
        logger.info(f"Registered TTS engine: {engine_type.value}")

    @classmethod
    def get_supported_engines(cls) -> list[str]:
        """
        Return the string values of every TTSEngineType member.

        Returns:
            The engine type values, in enum declaration order.
        """
        return [e.value for e in TTSEngineType]

    @classmethod
    def clear_instances(cls) -> None:
        """Drop all cached engine instances; the next ``create`` rebuilds them."""
        cls._instances.clear()
        logger.debug("Cleared TTS engine instances cache")