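This guide will help you get started with the SQL database toolkit. For detailed documentation of all SQLDatabaseToolkit features and configurations, see the API reference.

The tools in SQLDatabaseToolkit are designed to interact with a SQL database. A common application is to enable agents to answer questions using data in a relational database, potentially in an iterative fashion (for example, recovering from errors).

⚠️ Security note ⚠️ Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain's or agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system.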

Setup

To enable automated tracing of individual tools, set your LangSmith API key:
import getpass
import os

os.environ["LANGSMITH_API_KEY"] = getpass.getpass("Enter your LangSmith API key: ")
os.environ["LANGSMITH_TRACING"] = "true"

Installation

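This toolkit lives in the langchain-community package:
pip install -qU langchain-community
For demonstration purposes, we will access a prompt in the LangChain Hub. We will also need langgraph to demonstrate the use of the toolkit with an agent, although it is not required to use the toolkit.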
pip install -qU langchainhub langgraph

Instantiation

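The SQLDatabaseToolkit toolkit requires a SQLDatabase object and an LLM or chat model. Below, we instantiate the toolkit with these objects. Let's first create a database object.

This guide uses the example Chinook database based on these instructions.

Below we use the requests library to pull the .sql file and create an in-memory SQLite database. Note that this approach is lightweight, but ephemeral and not thread-safe. If you prefer, you can follow the instructions to save the file locally as Chinook.db and instantiate the database via db = SQLDatabase.from_uri("sqlite:///Chinook.db").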
import sqlite3

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool


def get_engine_for_chinook_db():
    """Pull sql file, populate in-memory database, and create engine."""
    url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
    response = requests.get(url)
    sql_script = response.text

    connection = sqlite3.connect(":memory:", check_same_thread=False)
    connection.executescript(sql_script)
    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )


engine = get_engine_for_chinook_db()

db = SQLDatabase(engine)
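The engine above hands every caller the same `sqlite3` connection. The reason, illustrated here with the standard library only, is that each new connection to `:memory:` gets its own empty database, so a pool that opened fresh connections would not see the tables loaded from the script. The table name `t` and the `table_names` helper are hypothetical.

```python
import sqlite3

# First connection: create and fill a table in its private in-memory database.
first = sqlite3.connect(":memory:")
first.execute("CREATE TABLE t (x INTEGER)")
first.execute("INSERT INTO t VALUES (1)")

# Second connection: a separate, empty in-memory database.
second = sqlite3.connect(":memory:")


def table_names(conn):
    """Return the names of the tables visible to this connection."""
    cur = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    return [name for (name,) in cur.fetchall()]


first_tables = table_names(first)
second_tables = table_names(second)
print(first_tables, second_tables)
```

Only the first connection sees `t`, which is why the guide reuses a single connection instead of letting the pool create new ones.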
We will also need an LLM or chat model:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(temperature=0)
We can now instantiate the toolkit:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

Tools

View the available tools:
toolkit.get_tools()
[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),
 QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x10742d720>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x10742f7f0>, root_client=<openai.OpenAI object at 0x103d5fac0>, root_async_client=<openai.AsyncOpenAI object at 0x10742d780>, temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x10742d720>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x10742f7f0>, root_client=<openai.OpenAI object at 0x103d5fac0>, root_async_client=<openai.AsyncOpenAI object at 0x10742d780>, temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]
You can also use the individual tools directly:
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDatabaseTool,
)

Use within an agent

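Following the SQL Q&A tutorial, below we equip a simple question-answering agent with the tools in our toolkit. First we pull the relevant prompt template and populate it with its required parameters: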
from langchain_classic import hub

prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

assert len(prompt_template.messages) == 1
print(prompt_template.input_variables)
['dialect', 'top_k']
system_message = prompt_template.format(dialect="SQLite", top_k=5)
We then instantiate the agent:
from langchain.agents import create_agent

agent = create_agent(llm, toolkit.get_tools(), system_prompt=system_message)
And issue it a query:
example_query = "Which country's customers spent the most?"

events = agent.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()
================================ Human Message =================================

Which country's customers spent the most?
================================== Ai Message ==================================
Tool Calls:
  sql_db_list_tables (call_EBPjyfzqXzFutDn8BklYACLj)
 Call ID: call_EBPjyfzqXzFutDn8BklYACLj
  Args:
================================= Tool Message =================================
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
  sql_db_schema (call_kGcnKpxRVFIY8dPjYIJbRoVU)
 Call ID: call_kGcnKpxRVFIY8dPjYIJbRoVU
  Args:
    table_names: Customer, Invoice, InvoiceLine
================================= Tool Message =================================
Name: sql_db_schema


CREATE TABLE "Customer" (
 "CustomerId" INTEGER NOT NULL,
 "FirstName" NVARCHAR(40) NOT NULL,
 "LastName" NVARCHAR(20) NOT NULL,
 "Company" NVARCHAR(80),
 "Address" NVARCHAR(70),
 "City" NVARCHAR(40),
 "State" NVARCHAR(40),
 "Country" NVARCHAR(40),
 "PostalCode" NVARCHAR(10),
 "Phone" NVARCHAR(24),
 "Fax" NVARCHAR(24),
 "Email" NVARCHAR(60) NOT NULL,
 "SupportRepId" INTEGER,
 PRIMARY KEY ("CustomerId"),
 FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/


CREATE TABLE "Invoice" (
 "InvoiceId" INTEGER NOT NULL,
 "CustomerId" INTEGER NOT NULL,
 "InvoiceDate" DATETIME NOT NULL,
 "BillingAddress" NVARCHAR(70),
 "BillingCity" NVARCHAR(40),
 "BillingState" NVARCHAR(40),
 "BillingCountry" NVARCHAR(40),
 "BillingPostalCode" NVARCHAR(10),
 "Total" NUMERIC(10, 2) NOT NULL,
 PRIMARY KEY ("InvoiceId"),
 FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/


CREATE TABLE "InvoiceLine" (
 "InvoiceLineId" INTEGER NOT NULL,
 "InvoiceId" INTEGER NOT NULL,
 "TrackId" INTEGER NOT NULL,
 "UnitPrice" NUMERIC(10, 2) NOT NULL,
 "Quantity" INTEGER NOT NULL,
 PRIMARY KEY ("InvoiceLineId"),
 FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
 FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)

/*
3 rows from InvoiceLine table:
InvoiceLineId InvoiceId TrackId UnitPrice Quantity
1 1 2 0.99 1
2 1 4 0.99 1
3 2 6 0.99 1
*/
================================== Ai Message ==================================
Tool Calls:
  sql_db_query (call_cTfI7OrY64FzJaDd49ILFWw7)
 Call ID: call_cTfI7OrY64FzJaDd49ILFWw7
  Args:
    query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1
================================= Tool Message =================================
Name: sql_db_query

[('USA', 523.06)]
================================== Ai Message ==================================

Customers from the USA spent the most, with a total amount spent of $523.06.
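To see what the generated query computes, the same join-and-aggregate shape can be run by hand on a tiny dataset. This sketch uses the standard `sqlite3` module with made-up rows (not Chinook data) and keeps only the columns the query touches.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
    CREATE TABLE Invoice (
        InvoiceId INTEGER PRIMARY KEY,
        CustomerId INTEGER NOT NULL REFERENCES Customer (CustomerId),
        Total NUMERIC NOT NULL
    );
    -- Made-up rows for illustration only.
    INSERT INTO Customer VALUES (1, 'USA'), (2, 'USA'), (3, 'Canada');
    INSERT INTO Invoice VALUES (1, 1, 5), (2, 2, 7), (3, 3, 10);
    """
)

# Same shape as the agent's query: join, sum per country, keep the largest.
query = """
SELECT c.Country, SUM(i.Total) AS TotalSpent
FROM Customer c
JOIN Invoice i ON c.CustomerId = i.CustomerId
GROUP BY c.Country
ORDER BY TotalSpent DESC
LIMIT 1
"""
top = conn.execute(query).fetchall()
print(top)
```

In this toy data the single largest invoice belongs to Canada, yet the USA comes out on top because the totals are summed per country before ranking.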
We can also observe the agent recover from errors:
example_query = "Who are the top 3 best selling artists?"

events = agent.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()
================================ Human Message =================================

Who are the top 3 best selling artists?
================================== Ai Message ==================================
Tool Calls:
  sql_db_query (call_xAkvYiRFM7nCMKXsDNvk1OMx)
 Call ID: call_xAkvYiRFM7nCMKXsDNvk1OMx
  Args:
    query: SELECT artist_name, SUM(quantity) AS total_sold FROM sales GROUP BY artist_name ORDER BY total_sold DESC LIMIT 3
================================= Tool Message =================================
Name: sql_db_query

Error: (sqlite3.OperationalError) no such table: sales
[SQL: SELECT artist_name, SUM(quantity) AS total_sold FROM sales GROUP BY artist_name ORDER BY total_sold DESC LIMIT 3]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
================================== Ai Message ==================================
Tool Calls:
  sql_db_list_tables (call_K4Zvbowsq7XPgGFepbvc5G7i)
 Call ID: call_K4Zvbowsq7XPgGFepbvc5G7i
  Args:
================================= Tool Message =================================
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
  sql_db_schema (call_tUztueSK7VO2klZ99xT4ZVhM)
 Call ID: call_tUztueSK7VO2klZ99xT4ZVhM
  Args:
    table_names: Artist, Album, InvoiceLine
================================= Tool Message =================================
Name: sql_db_schema


CREATE TABLE "Album" (
 "AlbumId" INTEGER NOT NULL,
 "Title" NVARCHAR(160) NOT NULL,
 "ArtistId" INTEGER NOT NULL,
 PRIMARY KEY ("AlbumId"),
 FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId Title ArtistId
1 For Those About To Rock We Salute You 1
2 Balls to the Wall 2
3 Restless and Wild 2
*/


CREATE TABLE "Artist" (
 "ArtistId" INTEGER NOT NULL,
 "Name" NVARCHAR(120),
 PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/


CREATE TABLE "InvoiceLine" (
 "InvoiceLineId" INTEGER NOT NULL,
 "InvoiceId" INTEGER NOT NULL,
 "TrackId" INTEGER NOT NULL,
 "UnitPrice" NUMERIC(10, 2) NOT NULL,
 "Quantity" INTEGER NOT NULL,
 PRIMARY KEY ("InvoiceLineId"),
 FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
 FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)

/*
3 rows from InvoiceLine table:
InvoiceLineId InvoiceId TrackId UnitPrice Quantity
1 1 2 0.99 1
2 1 4 0.99 1
3 2 6 0.99 1
*/
================================== Ai Message ==================================
Tool Calls:
  sql_db_query (call_tVtLQIRPmCM6pukgpHFfq86A)
 Call ID: call_tVtLQIRPmCM6pukgpHFfq86A
  Args:
    query: SELECT Artist.Name AS artist_name, SUM(InvoiceLine.Quantity) AS total_sold FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY total_sold DESC LIMIT 3
================================= Tool Message =================================
Name: sql_db_query

[('Iron Maiden', 140), ('U2', 107), ('Metallica', 91)]
================================== Ai Message ==================================

The top 3 best selling artists are:
1. Iron Maiden - 140 units sold
2. U2 - 107 units sold
3. Metallica - 91 units sold
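The recovery pattern in the trace (failed query, list tables, retry) can be imitated without a model. This is a hand-written sketch using the standard `sqlite3` module; the `run_query` and `list_tables` helpers are hypothetical stand-ins that loosely mirror what `sql_db_query` and `sql_db_list_tables` return, not the toolkit's implementation.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("INSERT INTO Artist (Name) VALUES ('AC/DC'), ('Accept')")


def run_query(sql):
    """Return rows on success, or an error string instead of raising."""
    try:
        return conn.execute(sql).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"


def list_tables():
    """Return a comma-separated list of table names."""
    cur = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
    return ", ".join(name for (name,) in cur.fetchall())


# First attempt guesses a table that does not exist.
first_try = run_query("SELECT name FROM artists")

# On error, inspect the real tables and retry with a corrected query.
tables = None
second_try = None
if isinstance(first_try, str):
    tables = list_tables()
    second_try = run_query("SELECT Name FROM Artist ORDER BY ArtistId")

print(first_try, tables, second_try)
```

Returning the error as text instead of raising is what lets the caller, here plain Python and in the guide the model, read the message and decide on the next step.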

Specific functionality

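SQLDatabaseToolkit implements a .get_context method as a convenience for use in prompts or other contexts.

⚠️ Disclaimer ⚠️: The agent may generate insert, update, or delete queries. If this is not desired, use a custom prompt or create a SQL user without write permissions.

An end user might also overload your SQL database by asking something like "run the biggest query possible". The generated query might look like: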
SELECT * FROM "public"."users"
    JOIN "public"."user_permissions" ON "public"."users".id = "public"."user_permissions".user_id
    JOIN "public"."projects" ON "public"."users".id = "public"."projects".user_id
    JOIN "public"."events" ON "public"."projects".id = "public"."events".project_id;
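For a transactional SQL database, if one of the tables above contains millions of rows, the query might cause trouble for other applications using the same database. Most data-warehouse-oriented databases support user-level quotas for limiting resource usage.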
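For SQLite specifically, one hedge against runaway queries is a time budget enforced through the `sqlite3` progress handler, which aborts a statement when the callback returns a non-zero value. The sketch below is an assumption about how one might apply that idea, not a feature of the toolkit; `run_with_budget` and the 0.05-second budget are hypothetical.

```python
import sqlite3
import time


def run_with_budget(conn, sql, seconds):
    """Run sql, aborting it once it has used more than `seconds` of wall time."""
    deadline = time.monotonic() + seconds

    def over_budget():
        # A non-zero return value tells SQLite to abort the running statement.
        return 1 if time.monotonic() > deadline else 0

    # Invoke the callback every 1000 SQLite virtual machine instructions.
    conn.set_progress_handler(over_budget, 1000)
    try:
        return conn.execute(sql).fetchall()
    except sqlite3.OperationalError as exc:
        return f"Error: {exc}"
    finally:
        # Remove the handler so later statements are not affected.
        conn.set_progress_handler(None, 1000)


conn = sqlite3.connect(":memory:")

# A cheap query finishes well within the budget.
quick = run_with_budget(conn, "SELECT 1", 0.05)

# An unbounded recursive query never finishes on its own and is cut off.
endless = """
WITH RECURSIVE counter(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM counter)
SELECT COUNT(*) FROM counter
"""
slow = run_with_budget(conn, endless, 0.05)
print(quick, slow)
```

Server databases generally expose their own mechanisms for this, such as statement timeouts or the user-level quotas mentioned above, and those are usually preferable to client-side guards.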

API reference

For detailed documentation of all SQLDatabaseToolkit features and configurations, see the API reference.