116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
"""
|
|
TTS 引擎工厂类
|
|
"""
|
|
from enum import Enum
|
|
from typing import Optional
|
|
from .base import TTSEngine
|
|
from .edge_tts_engine import EdgeTTSEngine
|
|
from .cosyvoice_engine import CosyVoiceEngine
|
|
from utils.logger import logger
|
|
|
|
|
|
class TTSEngineType(Enum):
    """Identifiers of the TTS engines known to this module.

    Each member's value is the string form of the engine identifier, so
    ``TTSEngineType("edge-tts")`` yields ``TTSEngineType.EDGE_TTS``.
    """

    EDGE_TTS = "edge-tts"
    COSYVOICE = "cosyvoice"

    # Further engine types can be declared here, for example:
    #   GOOGLE_TTS = "google-tts"
    #   BAIDU_TTS = "baidu-tts"
    #   AZURE_TTS = "azure-tts"
|
|
|
|
|
|
class TTSEngineFactory:
    """
    Factory for TTS engine instances.

    Maps each ``TTSEngineType`` to an engine class and caches one instance
    per type, so repeated ``create`` calls for the same type return the same
    object. Other implementations can be plugged in via ``register_engine``.
    """

    # Engine type -> implementing class. Register new implementations here
    # or at runtime through ``register_engine``.
    _engines = {
        TTSEngineType.EDGE_TTS: EdgeTTSEngine,
        TTSEngineType.COSYVOICE: CosyVoiceEngine,
    }

    # Engine type -> instance already built by ``create``.
    _instances: dict[TTSEngineType, TTSEngine] = {}

    @classmethod
    def _resolve_type(cls, engine_type: str | TTSEngineType) -> TTSEngineType:
        """Convert a string or enum member into a ``TTSEngineType``.

        Raises:
            ValueError: If ``engine_type`` does not correspond to any
                ``TTSEngineType`` member.
        """
        try:
            # Looking up an existing member returns that member unchanged, so
            # one call covers both strings and enum members; any other value
            # (e.g. None) is rejected with ValueError as well.
            return TTSEngineType(engine_type)
        except ValueError:
            # ``from None`` hides the low-level enum error; the message below
            # already carries everything the caller needs.
            raise ValueError(
                f"Unsupported TTS engine type: {engine_type}. "
                f"Supported types: {[e.value for e in TTSEngineType]}"
            ) from None

    @classmethod
    def create(cls, engine_type: str | TTSEngineType) -> TTSEngine:
        """
        Return the TTS engine instance for ``engine_type``.

        The instance is created on first request and cached; later calls for
        the same type return the cached object.

        Args:
            engine_type: Engine type, either a string value or a
                ``TTSEngineType`` member.

        Returns:
            The TTSEngine instance for the requested type.

        Raises:
            ValueError: If the engine type is unknown or has no registered
                implementation.
        """
        resolved = cls._resolve_type(engine_type)

        if resolved not in cls._instances:
            engine_class = cls._engines.get(resolved)
            if engine_class is None:
                raise ValueError(
                    f"TTS engine '{resolved.value}' is not registered. "
                    f"Available engines: {[e.value for e in cls._engines]}"
                )

            cls._instances[resolved] = engine_class()
            logger.info(f"Created TTS engine instance: {resolved.value}")

        return cls._instances[resolved]

    @classmethod
    def register_engine(
        cls, engine_type: str | TTSEngineType, engine_class: type[TTSEngine]
    ) -> None:
        """
        Register (or replace) the implementation class for an engine type.

        Args:
            engine_type: Engine type identifier (string value or enum member).
            engine_class: Engine class; must be a subclass of TTSEngine.

        Raises:
            TypeError: If ``engine_class`` is not a subclass of TTSEngine.
            ValueError: If ``engine_type`` is not a known ``TTSEngineType``.
        """
        # The isinstance guard keeps issubclass from failing with an unrelated
        # message when a non-class object is passed in.
        if not (isinstance(engine_class, type) and issubclass(engine_class, TTSEngine)):
            raise TypeError(f"{engine_class} must be a subclass of TTSEngine")

        resolved = cls._resolve_type(engine_type)

        # An instance cached from a different class would otherwise keep being
        # returned by ``create``; drop it so the new class takes effect.
        if cls._engines.get(resolved) is not engine_class:
            cls._instances.pop(resolved, None)

        cls._engines[resolved] = engine_class
        logger.info(f"Registered TTS engine: {resolved.value}")

    @classmethod
    def get_supported_engines(cls) -> list[str]:
        """
        List the engine types that ``create`` can instantiate.

        Returns:
            String values of all engine types with a registered implementation.
        """
        return [engine_type.value for engine_type in cls._engines]

    @classmethod
    def clear_instances(cls) -> None:
        """Discard every cached engine instance."""
        cls._instances.clear()
        logger.debug("Cleared TTS engine instances cache")
|