209 lines
5.7 KiB
Python
209 lines
5.7 KiB
Python
"""
|
|
CosyVoice 集成测试文件
|
|
|
|
测试 CosyVoice 引擎的基本功能
|
|
"""
|
|
import asyncio
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# 确保可以导入项目模块
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
|
|
async def test_cosyvoice_factory():
|
|
"""测试使用工厂模式创建 CosyVoice 引擎"""
|
|
print("\n" + "=" * 60)
|
|
print("测试 1: 工厂模式创建 CosyVoice 引擎")
|
|
print("=" * 60)
|
|
|
|
try:
|
|
from tts.factory import TTSEngineFactory
|
|
|
|
# 创建引擎
|
|
engine = TTSEngineFactory.create("cosyvoice")
|
|
print(f"✓ 引擎创建成功: {engine.get_engine_name()}")
|
|
print(f" 版本: {engine.get_engine_version()}")
|
|
|
|
# 获取示例声音
|
|
voices = await engine.get_supported_voices()
|
|
print(f"✓ 获取示例声音列表: {len(voices)} 个")
|
|
for voice in voices:
|
|
print(f" - {voice['name']}: {voice['voice_id']}")
|
|
|
|
except Exception as e:
|
|
print(f"✗ 错误: {e}")
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
async def test_cosyvoice_direct():
|
|
"""测试直接创建 CosyVoice 引擎实例"""
|
|
print("\n" + "=" * 60)
|
|
print("测试 2: 直接创建 CosyVoice 引擎实例")
|
|
print("=" * 60)
|
|
|
|
try:
|
|
from tts.cosyvoice_engine import CosyVoiceEngine
|
|
|
|
# 创建引擎实例
|
|
engine = CosyVoiceEngine(
|
|
api_url="http://192.168.1.200:8000/tts/zero_shot",
|
|
timeout=30.0,
|
|
)
|
|
print(f"✓ 引擎实例创建成功")
|
|
print(f" 名称: {engine.get_engine_name()}")
|
|
print(f" 版本: {engine.get_engine_version()}")
|
|
print(f" API URL: http://192.168.1.200:8000/tts/zero_shot")
|
|
|
|
# 关闭连接
|
|
await engine.close()
|
|
print(f"✓ HTTP 客户端连接已关闭")
|
|
|
|
except Exception as e:
|
|
print(f"✗ 错误: {e}")
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
async def test_synthesize_without_voice():
|
|
"""测试缺少 voice 参数时的错误处理"""
|
|
print("\n" + "=" * 60)
|
|
print("测试 3: 验证 voice 参数是否为必需")
|
|
print("=" * 60)
|
|
|
|
try:
|
|
from tts.factory import TTSEngineFactory
|
|
|
|
engine = TTSEngineFactory.create("cosyvoice")
|
|
|
|
# 尝试不提供 voice 参数
|
|
try:
|
|
await engine.synthesize("测试文本")
|
|
print("✗ 应该抛出 ValueError")
|
|
return False
|
|
except ValueError as e:
|
|
print(f"✓ 正确抛出 ValueError: {e}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"✗ 意外错误: {e}")
|
|
return False
|
|
|
|
|
|
async def test_available_engines():
|
|
"""测试工厂支持的所有引擎"""
|
|
print("\n" + "=" * 60)
|
|
print("测试 4: 检查支持的引擎列表")
|
|
print("=" * 60)
|
|
|
|
try:
|
|
from tts.factory import TTSEngineFactory
|
|
|
|
engines = TTSEngineFactory.get_supported_engines()
|
|
print(f"✓ 支持的引擎列表:")
|
|
for engine_name in engines:
|
|
print(f" - {engine_name}")
|
|
|
|
# 验证 cosyvoice 在列表中
|
|
if "cosyvoice" in engines:
|
|
print(f"✓ cosyvoice 已注册到工厂")
|
|
return True
|
|
else:
|
|
print(f"✗ cosyvoice 未在支持列表中")
|
|
return False
|
|
|
|
except Exception as e:
|
|
print(f"✗ 错误: {e}")
|
|
return False
|
|
|
|
|
|
async def test_engine_comparison():
|
|
"""测试引擎之间的差异"""
|
|
print("\n" + "=" * 60)
|
|
print("测试 5: 引擎对比")
|
|
print("=" * 60)
|
|
|
|
try:
|
|
from tts.factory import TTSEngineFactory
|
|
|
|
engines_to_test = ["edge-tts", "cosyvoice"]
|
|
results = {}
|
|
|
|
for engine_name in engines_to_test:
|
|
try:
|
|
engine = TTSEngineFactory.create(engine_name)
|
|
results[engine_name] = {
|
|
"name": engine.get_engine_name(),
|
|
"version": engine.get_engine_version(),
|
|
"status": "✓ 已注册",
|
|
}
|
|
except ValueError as e:
|
|
results[engine_name] = {
|
|
"status": f"✗ {e}",
|
|
}
|
|
|
|
print("\n引擎对比表:")
|
|
print(f"{'引擎名称':<15} {'状态':<20}")
|
|
print("-" * 35)
|
|
for engine_name, info in results.items():
|
|
print(f"{engine_name:<15} {info['status']:<20}")
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"✗ 错误: {e}")
|
|
return False
|
|
|
|
|
|
async def main():
|
|
"""运行所有测试"""
|
|
print("\n")
|
|
print("╔" + "=" * 58 + "╗")
|
|
print("║" + " " * 58 + "║")
|
|
print("║" + " CosyVoice 引擎集成测试".center(58) + "║")
|
|
print("║" + " " * 58 + "║")
|
|
print("╚" + "=" * 58 + "╝")
|
|
|
|
tests = [
|
|
("工厂模式创建", test_cosyvoice_factory),
|
|
("直接创建实例", test_cosyvoice_direct),
|
|
("参数验证", test_synthesize_without_voice),
|
|
("支持的引擎", test_available_engines),
|
|
("引擎对比", test_engine_comparison),
|
|
]
|
|
|
|
results = []
|
|
for test_name, test_func in tests:
|
|
try:
|
|
result = await test_func()
|
|
results.append((test_name, result))
|
|
except Exception as e:
|
|
print(f"\n✗ 测试异常: {e}")
|
|
results.append((test_name, False))
|
|
|
|
# 打印测试总结
|
|
print("\n" + "=" * 60)
|
|
print("测试总结")
|
|
print("=" * 60)
|
|
|
|
passed = sum(1 for _, result in results if result)
|
|
total = len(results)
|
|
|
|
for test_name, result in results:
|
|
status = "✓ 通过" if result else "✗ 失败"
|
|
print(f"{status} {test_name}")
|
|
|
|
print("-" * 60)
|
|
print(f"总计: {passed}/{total} 通过")
|
|
print("=" * 60)
|
|
|
|
return passed == total
|
|
|
|
|
|
if __name__ == "__main__":
|
|
success = asyncio.run(main())
|
|
sys.exit(0 if success else 1)
|