from typing import TypedDict
from pydantic import BaseModel
from langgraph.graph import StateGraph, START, END
from langchain.agents import create_agent
from langchain.tools import tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
class State(TypedDict):
    """Shared state passed between the nodes of the RAG workflow graph.

    Each node returns a partial dict that updates some of these keys.
    """

    # The user's original question; the graph is invoked with this key.
    question: str
    # Retrieval-oriented rewrite of `question`, written by `rewrite_query`.
    rewritten_query: str
    # Page contents of the retrieved documents, written by `retrieve`.
    documents: list[str]
    # Final answer, written by `call_agent`.
    answer: str
# WNBA knowledge base, split by topic: team rosters, game results, and player
# statistics. The lists are concatenated (rosters, results, stats) and embedded
# into an in-memory vector store in a single add_texts call.
_ROSTERS = [
    "New York Liberty 2024 阵容: Breanna Stewart, Sabrina Ionescu, Jonquel Jones, Courtney Vandersloot.",
    "Las Vegas Aces 2024 阵容: A'ja Wilson, Kelsey Plum, Jackie Young, Chelsea Gray.",
    "Indiana Fever 2024 阵容: Caitlin Clark, Aliyah Boston, Kelsey Mitchell, NaLyssa Smith.",
]
_GAME_RESULTS = [
    "2024 WNBA 总决赛: New York Liberty 击败 Minnesota Lynx 3-2 赢得冠军。",
    "2024年6月15日: Indiana Fever 85, Chicago Sky 79。Caitlin Clark 得到 23 分和 8 次助攻。",
    "2024年8月20日: Las Vegas Aces 92, Phoenix Mercury 84。A'ja Wilson 得到 35 分。",
]
_PLAYER_STATS = [
    "A'ja Wilson 2024 赛季统计数据: 26.9 PPG, 11.9 RPG, 2.6 BPG。赢得 MVP 奖项。",
    "Caitlin Clark 2024 新秀统计数据: 19.2 PPG, 8.4 APG, 5.7 RPG。赢得年度最佳新秀。",
    "Breanna Stewart 2024 统计数据: 20.4 PPG, 8.5 RPG, 3.5 APG。",
]

embeddings = OpenAIEmbeddings()
vector_store = InMemoryVectorStore(embeddings)
vector_store.add_texts(_ROSTERS + _GAME_RESULTS + _PLAYER_STATS)

# Retriever used by the `retrieve` node; asks for the top 5 matches per query.
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
# NOTE: the docstring below is left untranslated on purpose. The @tool decorator
# uses it as the tool description shown to the model, so it is runtime text.
@tool
def get_latest_news(query: str) -> str:
    """获取最新的 WNBA 新闻和更新。"""
    # Placeholder for a real news API call: `query` is currently ignored and a
    # fixed stub message is returned.
    canned_headline = "最新消息:WNBA 宣布了 2025 年扩大的季后赛赛制..."
    return canned_headline
# Agent that produces the final answer (see `call_agent`); it can call the
# `get_latest_news` tool.
agent = create_agent(model="openai:gpt-4.1", tools=[get_latest_news])

# Chat model used directly, outside the agent, by `rewrite_query` for
# structured query rewriting.
model = ChatOpenAI(model="gpt-4.1")
# Structured-output schema for `rewrite_query`: the model must return a single
# `query` string. No docstring is added on purpose, because with_structured_output
# would send it to the model as the schema description.
class RewrittenQuery(BaseModel):
    # The retrieval-optimized rewrite of the user's question.
    query: str
def rewrite_query(state: State) -> dict:
    """Rewrite the user's question into a query better suited for retrieval.

    Sends a system prompt that describes the knowledge base, followed by the
    user's question, to the chat model. The reply is constrained to the
    `RewrittenQuery` schema.

    Returns:
        A partial state update ``{"rewritten_query": str}``.
    """
    # Same prompt text as before, joined line by line with "\n".
    instructions = (
        "重写此查询以检索相关的 WNBA 信息。\n"
        "知识库包含:球队阵容、带比分的比赛结果以及球员统计数据(PPG、RPG、APG)。\n"
        "重点关注提到的特定球员姓名、球队名称或统计类别。"
    )
    structured_model = model.with_structured_output(RewrittenQuery)
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": state["question"]},
    ]
    rewritten = structured_model.invoke(conversation)
    return {"rewritten_query": rewritten.query}
def retrieve(state: State) -> dict:
    """Fetch documents matching the rewritten query.

    Returns:
        A partial state update ``{"documents": list[str]}`` holding the
        ``page_content`` of each retrieved document, in retriever order.
    """
    search_text = state["rewritten_query"]
    matches = retriever.invoke(search_text)
    page_texts = [match.page_content for match in matches]
    return {"documents": page_texts}
def call_agent(state: State) -> dict:
    """Answer the user's question using the retrieved documents as context.

    Builds one user message: the retrieved documents (separated by blank
    lines), followed by the user's original (not rewritten) question. The
    agent is run on that message.

    Returns:
        A partial state update ``{"answer": str}`` holding the text of the
        agent's final message.
    """
    context = "\n\n".join(state["documents"])
    prompt = f"Context:\n{context}\n\nQuestion: {state['question']}"
    response = agent.invoke({"messages": [{"role": "user", "content": prompt}]})
    final_message = response["messages"][-1]
    # `State.answer` is declared as `str`. `content_blocks` is a list of
    # content-block dicts, so storing it broke that contract and printed raw
    # block dicts. `.text` returns the message's text content as a string.
    return {"answer": final_message.text}
def _build_workflow():
    """Compile the linear graph: START -> rewrite -> retrieve -> agent -> END."""
    builder = StateGraph(State)
    builder.add_node("rewrite", rewrite_query)
    builder.add_node("retrieve", retrieve)
    builder.add_node("agent", call_agent)
    builder.add_edge(START, "rewrite")
    builder.add_edge("rewrite", "retrieve")
    builder.add_edge("retrieve", "agent")
    builder.add_edge("agent", END)
    return builder.compile()


workflow = _build_workflow()

# Example run: only `question` is supplied; the nodes fill in the other keys.
result = workflow.invoke({"question": "Who won the 2024 WNBA Championship?"})
print(result["answer"])