This commit is contained in:
208
tts/test_cosyvoice.py
Normal file
208
tts/test_cosyvoice.py
Normal file
@ -0,0 +1,208 @@
|
||||
"""
|
||||
CosyVoice 集成测试文件
|
||||
|
||||
测试 CosyVoice 引擎的基本功能
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 确保可以导入项目模块
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
async def test_cosyvoice_factory():
|
||||
"""测试使用工厂模式创建 CosyVoice 引擎"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 1: 工厂模式创建 CosyVoice 引擎")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from tts.factory import TTSEngineFactory
|
||||
|
||||
# 创建引擎
|
||||
engine = TTSEngineFactory.create("cosyvoice")
|
||||
print(f"✓ 引擎创建成功: {engine.get_engine_name()}")
|
||||
print(f" 版本: {engine.get_engine_version()}")
|
||||
|
||||
# 获取示例声音
|
||||
voices = await engine.get_supported_voices()
|
||||
print(f"✓ 获取示例声音列表: {len(voices)} 个")
|
||||
for voice in voices:
|
||||
print(f" - {voice['name']}: {voice['voice_id']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 错误: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def test_cosyvoice_direct():
|
||||
"""测试直接创建 CosyVoice 引擎实例"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 2: 直接创建 CosyVoice 引擎实例")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from tts.cosyvoice_engine import CosyVoiceEngine
|
||||
|
||||
# 创建引擎实例
|
||||
engine = CosyVoiceEngine(
|
||||
api_url="http://192.168.1.200:8000/tts/zero_shot",
|
||||
timeout=30.0,
|
||||
)
|
||||
print(f"✓ 引擎实例创建成功")
|
||||
print(f" 名称: {engine.get_engine_name()}")
|
||||
print(f" 版本: {engine.get_engine_version()}")
|
||||
print(f" API URL: http://192.168.1.200:8000/tts/zero_shot")
|
||||
|
||||
# 关闭连接
|
||||
await engine.close()
|
||||
print(f"✓ HTTP 客户端连接已关闭")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 错误: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def test_synthesize_without_voice():
|
||||
"""测试缺少 voice 参数时的错误处理"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 3: 验证 voice 参数是否为必需")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from tts.factory import TTSEngineFactory
|
||||
|
||||
engine = TTSEngineFactory.create("cosyvoice")
|
||||
|
||||
# 尝试不提供 voice 参数
|
||||
try:
|
||||
await engine.synthesize("测试文本")
|
||||
print("✗ 应该抛出 ValueError")
|
||||
return False
|
||||
except ValueError as e:
|
||||
print(f"✓ 正确抛出 ValueError: {e}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 意外错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_available_engines():
|
||||
"""测试工厂支持的所有引擎"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 4: 检查支持的引擎列表")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from tts.factory import TTSEngineFactory
|
||||
|
||||
engines = TTSEngineFactory.get_supported_engines()
|
||||
print(f"✓ 支持的引擎列表:")
|
||||
for engine_name in engines:
|
||||
print(f" - {engine_name}")
|
||||
|
||||
# 验证 cosyvoice 在列表中
|
||||
if "cosyvoice" in engines:
|
||||
print(f"✓ cosyvoice 已注册到工厂")
|
||||
return True
|
||||
else:
|
||||
print(f"✗ cosyvoice 未在支持列表中")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_engine_comparison():
|
||||
"""测试引擎之间的差异"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 5: 引擎对比")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from tts.factory import TTSEngineFactory
|
||||
|
||||
engines_to_test = ["edge-tts", "cosyvoice"]
|
||||
results = {}
|
||||
|
||||
for engine_name in engines_to_test:
|
||||
try:
|
||||
engine = TTSEngineFactory.create(engine_name)
|
||||
results[engine_name] = {
|
||||
"name": engine.get_engine_name(),
|
||||
"version": engine.get_engine_version(),
|
||||
"status": "✓ 已注册",
|
||||
}
|
||||
except ValueError as e:
|
||||
results[engine_name] = {
|
||||
"status": f"✗ {e}",
|
||||
}
|
||||
|
||||
print("\n引擎对比表:")
|
||||
print(f"{'引擎名称':<15} {'状态':<20}")
|
||||
print("-" * 35)
|
||||
for engine_name, info in results.items():
|
||||
print(f"{engine_name:<15} {info['status']:<20}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""运行所有测试"""
|
||||
print("\n")
|
||||
print("╔" + "=" * 58 + "╗")
|
||||
print("║" + " " * 58 + "║")
|
||||
print("║" + " CosyVoice 引擎集成测试".center(58) + "║")
|
||||
print("║" + " " * 58 + "║")
|
||||
print("╚" + "=" * 58 + "╝")
|
||||
|
||||
tests = [
|
||||
("工厂模式创建", test_cosyvoice_factory),
|
||||
("直接创建实例", test_cosyvoice_direct),
|
||||
("参数验证", test_synthesize_without_voice),
|
||||
("支持的引擎", test_available_engines),
|
||||
("引擎对比", test_engine_comparison),
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
result = await test_func()
|
||||
results.append((test_name, result))
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试异常: {e}")
|
||||
results.append((test_name, False))
|
||||
|
||||
# 打印测试总结
|
||||
print("\n" + "=" * 60)
|
||||
print("测试总结")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for _, result in results if result)
|
||||
total = len(results)
|
||||
|
||||
for test_name, result in results:
|
||||
status = "✓ 通过" if result else "✗ 失败"
|
||||
print(f"{status} {test_name}")
|
||||
|
||||
print("-" * 60)
|
||||
print(f"总计: {passed}/{total} 通过")
|
||||
print("=" * 60)
|
||||
|
||||
return passed == total
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = asyncio.run(main())
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user