""" TTS 引擎工厂类 """ from enum import Enum from typing import Optional from .base import TTSEngine from .edge_tts_engine import EdgeTTSEngine from .cosyvoice_engine import CosyVoiceEngine from utils.logger import logger class TTSEngineType(Enum): """支持的 TTS 引擎类型""" EDGE_TTS = "edge-tts" COSYVOICE = "cosyvoice" # 可以在这里添加更多引擎类型 # GOOGLE_TTS = "google-tts" # BAIDU_TTS = "baidu-tts" # AZURE_TTS = "azure-tts" class TTSEngineFactory: """ TTS 引擎工厂 负责创建和管理 TTS 引擎实例。支持多引擎扩展。 """ _engines = { TTSEngineType.EDGE_TTS: EdgeTTSEngine, TTSEngineType.COSYVOICE: CosyVoiceEngine, # 添加其他引擎实现时在这里注册 } _instances: dict[TTSEngineType, TTSEngine] = {} @classmethod def create(cls, engine_type: str | TTSEngineType) -> TTSEngine: """ 创建 TTS 引擎实例(单例模式) Args: engine_type: 引擎类型,可以是字符串或 TTSEngineType 枚举 Returns: TTSEngine 实例 Raises: ValueError: 如果指定的引擎类型不支持 """ # 转换为 TTSEngineType if isinstance(engine_type, str): try: engine_type = TTSEngineType(engine_type) except ValueError: raise ValueError( f"Unsupported TTS engine type: {engine_type}. " f"Supported types: {[e.value for e in TTSEngineType]}" ) # 返回已缓存的实例或创建新实例 if engine_type not in cls._instances: if engine_type not in cls._engines: raise ValueError( f"TTS engine '{engine_type.value}' is not registered. " f"Available engines: {list(cls._engines.keys())}" ) engine_class = cls._engines[engine_type] instance = engine_class() cls._instances[engine_type] = instance logger.info(f"Created TTS engine instance: {engine_type.value}") return cls._instances[engine_type] @classmethod def register_engine( cls, engine_type: str | TTSEngineType, engine_class: type[TTSEngine] ) -> None: """ 注册新的 TTS 引擎类型 Args: engine_type: 引擎类型标识 engine_class: 引擎类,必须继承 TTSEngine Raises: TypeError: 如果 engine_class 不是 TTSEngine 的子类 """ if not issubclass(engine_class, TTSEngine): raise TypeError(f"{engine_class} must be a subclass of TTSEngine") # 转换为 TTSEngineType if isinstance(engine_type, str): engine_type = TTSEngineType(engine_type) cls._engines[engine_type] = engine_class logger.info(f"Registered TTS engine: {engine_type.value}") @classmethod def get_supported_engines(cls) -> list[str]: """ 获取所有支持的引擎类型 Returns: 支持的引擎类型列表 """ return [e.value for e in TTSEngineType] @classmethod def clear_instances(cls) -> None: """清空所有引擎实例缓存""" cls._instances.clear() logger.debug("Cleared TTS engine instances cache")