All checks were successful
Gitea Actions Demo / deploy (push) Successful in 13s
195 lines
6.9 KiB
Python
195 lines
6.9 KiB
Python
import json
|
|
from datetime import datetime, timedelta, timezone
|
|
import re
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from openai import OpenAI
|
|
|
|
from config.settings import settings
|
|
from llm import prompt as prompts
|
|
from utils.logger import logger
|
|
|
|
|
|
# DashScope's OpenAI-compatible endpoint, and the model name sent with every request.
BASE_URL: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
MODEL: str = "deepseek-v3.2"
|
|
|
|
|
|
def _make_client() -> OpenAI:
    """Return a freshly constructed OpenAI SDK client aimed at ``BASE_URL``.

    A new client is built on every call, and the API key is read from
    ``settings.DASHSCOPE_API_KEY`` at that moment.
    """
    api_key = settings.DASHSCOPE_API_KEY
    return OpenAI(base_url=BASE_URL, api_key=api_key)
|
|
|
|
|
|
def _call_model(system_prompt: Optional[str], user_prompt: str, stream: bool = False, enable_search: bool = False) -> Any:
    """Send one chat-completion request to ``MODEL`` and return the reply text.

    Args:
        system_prompt: Optional system message; omitted when falsy.
        user_prompt: Content of the single user message.
        stream: Passed straight to the SDK. When True the SDK returns a stream
            object instead of a completed response; no text can be read from
            it, so the raw stream object is returned for the caller to consume.
        enable_search: Forwarded to the endpoint via ``extra_body``.

    Returns:
        ``resp.choices[0].message.content`` (which may itself be None) when
        available, else ``resp.choices[0].text``, else the raw SDK response.
    """
    client = _make_client()

    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})

    resp = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        stream=stream,
        extra_body={"enable_search": enable_search},
    )

    # Only failures of the access pattern itself (missing attribute, empty
    # `choices`, non-subscriptable object) should trigger the fallbacks; any
    # other error is a real bug and must propagate.
    access_errors = (AttributeError, IndexError, KeyError, TypeError)
    try:
        # OpenAI-compatible chat shape: resp.choices[0].message.content
        return resp.choices[0].message.content
    except access_errors:
        pass
    try:
        # Legacy completion shape: resp.choices[0].text
        return resp.choices[0].text
    except access_errors:
        logger.debug("model response has no extractable text; returning raw response object")
        return resp
|
|
|
|
|
|
def _extract_json(text: str) -> str:
|
|
"""Attempt to extract the first JSON object/array from text."""
|
|
if not isinstance(text, str):
|
|
raise ValueError("Expected text to be str")
|
|
# Find first '[' or '{'
|
|
start_idx = None
|
|
for i, ch in enumerate(text):
|
|
if ch in "[{":
|
|
start_idx = i
|
|
break
|
|
if start_idx is None:
|
|
raise ValueError("No JSON object/array found in text")
|
|
|
|
# Try to find a matching closing bracket by scanning and counting
|
|
stack = []
|
|
for j in range(start_idx, len(text)):
|
|
ch = text[j]
|
|
if ch in "{[":
|
|
stack.append(ch)
|
|
elif ch in "]}":
|
|
if not stack:
|
|
continue
|
|
opening = stack.pop()
|
|
if (opening == "{" and ch != "}") or (opening == "[" and ch != "]"):
|
|
# mismatched, continue
|
|
continue
|
|
if not stack:
|
|
return text[start_idx : j + 1]
|
|
|
|
# Fallback: try regex to capture last '}' or ']' occurrence
|
|
m = re.search(r"(\{.*\}|\[.*\])", text, re.S)
|
|
if m:
|
|
return m.group(1)
|
|
raise ValueError("Could not extract JSON from model output")
|
|
|
|
|
|
def _parse_json_safe(text: str) -> Any:
|
|
try:
|
|
return json.loads(text)
|
|
except Exception:
|
|
# try to extract JSON substring
|
|
jtext = _extract_json(text)
|
|
return json.loads(jtext)
|
|
|
|
|
|
def generate_topics(start_time: Optional[str] = None, end_time: Optional[str] = None) -> List[Dict[str, Any]]:
    """Call prompt_a (with ``enable_search=True``) to get candidate meme topics.

    Args:
        start_time: Start of the scan window, ``YYYY-MM-DD``. Defaults to
            seven days before ``end_time``.
        end_time: End of the scan window, ``YYYY-MM-DD``. Defaults to today
            (UTC).

    The window replaces the literal ``start_time ~ end_time`` placeholder in
    ``prompts.prompt_a``. If the prompt has no such placeholder, the window is
    not sent to the model and a warning is logged.

    Returns:
        The JSON array parsed from the model output.

    Raises:
        ValueError: if ``end_time`` is not an ISO date while ``start_time``
            must be derived from it, if no JSON can be parsed from the model
            output, or if the parsed JSON is not an array.
    """
    # Compute defaults (UTC). start_time is derived from end_time even when
    # only end_time was supplied by the caller.
    if end_time is None:
        end_time = datetime.now(timezone.utc).date().isoformat()
    if start_time is None:
        end_date = datetime.fromisoformat(end_time).date()
        start_time = (end_date - timedelta(days=7)).isoformat()

    placeholder = "start_time ~ end_time"
    user_prompt = prompts.prompt_a
    if placeholder in user_prompt:
        user_prompt = user_prompt.replace(placeholder, f"{start_time} ~ {end_time}")
    else:
        logger.warning("prompt_a has no 'start_time ~ end_time' placeholder; time window not injected")

    logger.debug(f"prompt for generate_topics:\n{user_prompt}")

    content = _call_model(system_prompt=None, user_prompt=user_prompt, enable_search=True)
    logger.debug(f"raw output from generate_topics:\n{content}")

    if isinstance(content, (dict, list)):
        data = content
    else:
        text = content if isinstance(content, str) else str(content)
        data = _parse_json_safe(text)

    # Enforce the declared list return type on every path.
    if not isinstance(data, list):
        raise ValueError("prompt_a did not return a JSON array")
    logger.debug(f"result for generate_topics:\n{data}")
    return data
|
|
|
|
|
|
def _generate_json(prompt: str, meme_name: str, section: str, body: str) -> Any:
    """Build the user prompt, call the model, and return its parsed JSON reply.

    The user prompt is ``prompt``, a blank line, a ``meme_name: ...`` line, a
    ``<section>:`` label line, then ``body`` and a trailing newline. The reply
    is parsed with ``_parse_json_safe``; its JSON type is not validated.
    """
    user_prompt = prompt + f"\n\nmeme_name: {meme_name}\n{section}:\n{body}\n"
    content = _call_model(system_prompt=None, user_prompt=user_prompt)
    text = content if isinstance(content, str) else str(content)
    return _parse_json_safe(text)


def generate_bits(meme_name: str, research_text: str, prompt_bit: str = prompts.prompt_b) -> Dict[str, Any]:
    """Run ``prompt_bit`` (default ``prompts.prompt_b``) on a meme's research text.

    Same as ``generate_bit`` but with a default prompt. Returns the parsed
    JSON reply; an object is expected but not validated.
    """
    return generate_bit(meme_name, research_text, prompt_bit)


def generate_bit(meme_name: str, research_text: str, prompt_bit: str) -> Dict[str, Any]:
    """Run ``prompt_bit`` on a meme's research text and return the parsed JSON reply.

    An object is expected but not validated.
    """
    return _generate_json(prompt_bit, meme_name, "research", research_text)


def generate_script(meme_name: str, materials_text: str) -> Dict[str, Any]:
    """Run ``prompts.prompt_c`` on the collected materials and return the parsed JSON reply.

    An object is expected but not validated.
    """
    return _generate_json(prompts.prompt_c, meme_name, "materials", materials_text)
|
|
|
|
|
|
def _build_research_text(topic: Dict[str, Any]) -> str:
    """Render a topic's summary/origin/reach_estimate/angles fields as labelled lines."""
    parts: List[str] = []
    if "summary" in topic:
        parts.append(f"简介:{topic['summary']}")
    if "origin" in topic:
        parts.append(f"可能起源:{topic['origin']}")
    if "reach_estimate" in topic:
        parts.append(f"传播估计:{topic['reach_estimate']}")
    if "angles" in topic:
        angles = topic["angles"]
        # The model may emit null, a bare string, or non-string items; a plain
        # "; ".join would raise TypeError or split a string into characters.
        if angles is None:
            items: List[Any] = []
        elif isinstance(angles, (list, tuple)):
            items = list(angles)
        else:
            items = [angles]
        parts.append("角度:" + "; ".join(str(a) for a in items))
    return "\n".join(parts)


def orchestrate_for_first_topic() -> Dict[str, Any]:
    """Run the full pipeline on the first topic returned by ``generate_topics``.

    Steps:
      1. Fetch the topics.
      2. Build a research text from the first topic's known fields.
      3. Generate bits from that research text.
      4. Generate a script from the research text plus the JSON-serialised bits.

    Returns:
        ``{"topic": <first topic>, "bits": <bits JSON>, "script": <script JSON>}``.

    Raises:
        RuntimeError: if no topics are returned or the first topic is not a
            JSON object.
    """
    topics = generate_topics()
    if not topics:
        raise RuntimeError("No topics returned")

    top = topics[0]
    if not isinstance(top, dict):
        raise RuntimeError(f"First topic is not a JSON object: {top!r}")
    # Fall back to a placeholder name ("unknown meme") when no title/name exists.
    meme = top.get("title") or top.get("name") or "未知梗"

    research_text = _build_research_text(top)
    bits = generate_bits(meme, research_text)

    # Combine materials: research text + generated bits.
    materials = research_text + "\n\n" + json.dumps(bits, ensure_ascii=False, indent=2)
    script = generate_script(meme, materials)

    return {"topic": top, "bits": bits, "script": script}
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# # quick sanity check when run as script (will call API if keys present)
|
|
# try:
|
|
# out = orchestrate_for_first_topic()
|
|
# print(json.dumps(out, ensure_ascii=False, indent=2))
|
|
# except Exception as e:
|
|
# print(f"Error during orchestration: {e}")
|