Files
meme/llm/generate_podcast.py
konjacpotato 6772699cfe
Some checks failed
Gitea Actions Demo / deploy (push) Failing after 2s
commit code
2025-12-29 19:34:39 +08:00

195 lines
6.9 KiB
Python

import json
from datetime import datetime, timedelta, timezone
import re
from typing import Any, Dict, List, Optional
from openai import OpenAI
from config.settings import settings
from llm import prompt as prompts
from utils.logger import logger
# DashScope endpoint in OpenAI-compatible mode (per the URL path); passed as
# base_url to the OpenAI SDK client built by _make_client.
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Model identifier sent with every chat-completions request made by _call_model.
MODEL = "deepseek-v3.2"
def _make_client() -> OpenAI:
    """Return a new OpenAI SDK client pointed at the DashScope endpoint.

    A fresh client is constructed on every call, authenticated with
    ``settings.DASHSCOPE_API_KEY`` and using ``BASE_URL`` as its base URL.
    """
    api_key = settings.DASHSCOPE_API_KEY
    return OpenAI(base_url=BASE_URL, api_key=api_key)
def _call_model(
    system_prompt: Optional[str],
    user_prompt: str,
    stream: bool = False,
    enable_search: bool = False,
) -> Any:
    """Send one chat-completions request to ``MODEL`` and return the reply.

    Args:
        system_prompt: Optional system message; omitted when falsy.
        user_prompt: Content of the single user message.
        stream: Forwarded unchanged to the SDK call.
        enable_search: Forwarded to DashScope as ``extra_body["enable_search"]``.

    Returns:
        ``resp.choices[0].message.content`` when that path exists, otherwise
        ``resp.choices[0].text``, otherwise the raw response object (e.g.
        presumably a streaming response when ``stream=True`` — callers in this
        file only use the default ``stream=False``).
    """
    client = _make_client()
    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    resp = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        stream=stream,
        extra_body={"enable_search": enable_search},
    )
    # The response shape varies, so probe the common access patterns. Only
    # catch errors meaning "this path does not exist on the object" so that
    # unrelated failures are not silently converted into a raw-object return.
    lookup_errors = (AttributeError, IndexError, KeyError, TypeError)
    try:
        # OpenAI chat-completions shape.
        return resp.choices[0].message.content
    except lookup_errors:
        pass
    try:
        # Legacy completions shape.
        return resp.choices[0].text
    except lookup_errors:
        # Last resort: hand back the raw response.
        return resp
def _extract_json(text: str) -> str:
    """Attempt to extract the first JSON object/array from text.

    Scans from the first ``{`` or ``[`` and returns the substring ending at the
    bracket that balances it. Brackets that appear inside JSON string literals
    (including after escaped quotes) are ignored, so a value such as ``"a}b"``
    does not end the scan early. A mismatched closer is tolerated and the scan
    continues. If no balancing closer is found, a greedy regex fallback returns
    a ``{...}`` or ``[...]`` span running to the last matching closer.

    Raises:
        ValueError: if *text* is not a str, contains no ``{``/``[``, or no
            candidate span can be found.
    """
    if not isinstance(text, str):
        raise ValueError("Expected text to be str")
    start_idx = next((i for i, ch in enumerate(text) if ch in "[{"), None)
    if start_idx is None:
        raise ValueError("No JSON object/array found in text")
    closers = {"{": "}", "[": "]"}
    stack: List[str] = []
    in_string = False
    escaped = False
    for j, ch in enumerate(text[start_idx:], start=start_idx):
        if in_string:
            # Inside a JSON string only an unescaped closing quote matters.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch in closers:
            stack.append(ch)
        elif ch in "]}":
            if not stack:
                continue
            opening = stack.pop()
            if closers[opening] != ch:
                # Mismatched closer: tolerate it and keep scanning.
                continue
            if not stack:
                return text[start_idx : j + 1]
    # Fallback: greedy regex up to the last '}' or ']' occurrence.
    m = re.search(r"(\{.*\}|\[.*\])", text, re.S)
    if m:
        return m.group(1)
    raise ValueError("Could not extract JSON from model output")
def _parse_json_safe(text: str) -> Any:
    """Parse *text* as JSON, retrying on an embedded JSON substring.

    The whole string is tried first. If that raises for any reason, the
    substring located by ``_extract_json`` is parsed instead; errors from that
    second attempt propagate to the caller.
    """
    try:
        result = json.loads(text)
    except Exception:
        # Model replies may wrap the JSON in extra text; retry on the extract.
        result = json.loads(_extract_json(text))
    return result
def generate_topics(start_time: Optional[str] = None, end_time: Optional[str] = None) -> List[Dict[str, Any]]:
    """Call prompt_a to get a list of candidate meme topics.

    If start_time/end_time are provided (YYYY-MM-DD), they are injected into
    the prompt to limit the timeframe the model should scan.

    Missing bounds default to a 7-day window:
        end_time   = today (UTC, YYYY-MM-DD)
        start_time = end_time - 7 days
    When only end_time is supplied, start_time is derived from it the same way.

    The window replaces the literal ``start_time ~ end_time`` placeholder in
    prompt_a; if prompt_a has no such placeholder, a ``time_range:`` line is
    appended to the prompt instead.

    Returns:
        The parsed JSON array from the model reply (or the reply itself if
        ``_call_model`` already returned a dict/list).

    Raises:
        ValueError: if end_time is not an ISO date while start_time must be
            derived from it, if no JSON can be parsed from the reply, or if
            the parsed JSON is not an array.
    """
    today = datetime.now(timezone.utc).date()
    if start_time is None:
        # Derive the default start from the caller's end_time when one is given,
        # otherwise from today (UTC).
        anchor = today if end_time is None else datetime.fromisoformat(end_time).date()
        start_time = (anchor - timedelta(days=7)).isoformat()
    if end_time is None:
        end_time = today.isoformat()

    window = f"{start_time} ~ {end_time}"
    user_prompt = prompts.prompt_a
    placeholder = "start_time ~ end_time"
    if placeholder in user_prompt:
        user_prompt = user_prompt.replace(placeholder, window)
    else:
        # Keep the requested window visible to the model even without a placeholder.
        user_prompt = f"{user_prompt}\n\ntime_range: {window}\n"
    logger.debug(f"prompt for generate_topics:\n{user_prompt}")
    content = _call_model(system_prompt=None, user_prompt=user_prompt, enable_search=True)
    logger.debug(f"raw output from generate_topics:\n{content}")
    if isinstance(content, (dict, list)):
        return content
    text = content if isinstance(content, str) else str(content)
    data = _parse_json_safe(text)
    if not isinstance(data, list):
        raise ValueError("prompt_a did not return a JSON array")
    logger.debug(f"result for generate_topics:\n{data}")
    return data
def generate_bits(meme_name: str, research_text: str, prompt_bit: str = prompts.prompt_b) -> Dict[str, Any]:
    """Run *prompt_bit* (``prompts.prompt_b`` by default) for one meme.

    Identical to ``generate_bit`` except that the prompt has a default; it
    delegates there so the prompt layout and parsing live in one place.

    Returns:
        The model reply parsed as JSON.
    """
    return generate_bit(meme_name, research_text, prompt_bit)
def generate_bit(meme_name: str, research_text: str, prompt_bit: str) -> Dict[str, Any]:
    """Run *prompt_bit* for one meme and return the parsed JSON reply.

    The user prompt is *prompt_bit* followed by a ``meme_name:`` line and a
    ``research:`` section containing *research_text*. Non-string replies are
    stringified before parsing.
    """
    sections = f"\n\nmeme_name: {meme_name}\nresearch:\n{research_text}\n"
    reply = _call_model(system_prompt=None, user_prompt=prompt_bit + sections)
    if not isinstance(reply, str):
        reply = str(reply)
    return _parse_json_safe(reply)
def generate_script(meme_name: str, materials_text: str) -> Dict[str, Any]:
    """Ask the model for the final script using ``prompts.prompt_c``.

    The user prompt is ``prompts.prompt_c`` followed by a ``meme_name:`` line
    and a ``materials:`` section containing *materials_text*. The reply is
    stringified if needed and parsed as JSON.
    """
    tail = f"\n\nmeme_name: {meme_name}\nmaterials:\n{materials_text}\n"
    raw = _call_model(system_prompt=None, user_prompt=prompts.prompt_c + tail)
    return _parse_json_safe(raw if isinstance(raw, str) else str(raw))
def orchestrate_for_first_topic() -> Dict[str, Any]:
    """High-level orchestration: pick the first topic, build research text, create bits and final script.

    Returns:
        A dict with keys ``"topic"`` (the first topic returned by
        generate_topics), ``"bits"`` and ``"script"`` (parsed model replies).

    Raises:
        RuntimeError: if generate_topics returns no topics.
    """
    topics = generate_topics()
    if not topics:
        raise RuntimeError("No topics returned")
    top = topics[0]
    meme = top.get("title") or top.get("name") or "未知梗"

    # Build a concise research text from whichever topic fields are present,
    # in a fixed order.
    labelled_fields = (
        ("summary", "简介"),
        ("origin", "可能起源"),
        ("reach_estimate", "传播估计"),
    )
    parts = [f"{label}:{top[key]}" for key, label in labelled_fields if key in top]
    if "angles" in top:
        angles = top["angles"]
        # Topic dicts come from model output with no enforced schema: tolerate
        # null, a bare string, or non-string items instead of crashing in join
        # (or splitting a bare string into single characters).
        if angles is None:
            angles = []
        elif isinstance(angles, str):
            angles = [angles]
        parts.append("角度:" + "; ".join(str(angle) for angle in angles))
    research_text = "\n".join(parts)

    bits = generate_bits(meme, research_text)
    # Combine materials: research text derived from the topic + generated bits.
    materials = research_text + "\n\n" + json.dumps(bits, ensure_ascii=False, indent=2)
    script = generate_script(meme, materials)
    return {"topic": top, "bits": bits, "script": script}
if __name__ == "__main__":
    # Quick sanity check when run as a script (will call the API if keys are present).
    import sys

    try:
        out = orchestrate_for_first_topic()
    except Exception as e:
        # Report on stderr and exit non-zero so scripted/CI runs see the failure
        # instead of a successful exit status.
        print(f"Error during orchestration: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print(json.dumps(out, ensure_ascii=False, indent=2))