Skip to main content
langchain-nvidia-ai-endpoints 包包含了由 NVIDIA AI Foundation Models 提供支持并通过 NVIDIA API Catalog 托管的聊天模型和嵌入模型的 LangChain 集成。 NVIDIA AI Foundation 模型是由社区和 NVIDIA 构建的模型,经过优化以在 NVIDIA 加速的基础设施上提供最佳性能。您可以使用该 API 查询 NVIDIA API Catalog 上可用的实时端点,以从 DGX 托管的云计算环境中快速获取结果;或者,您可以使用包含在 NVIDIA AI Enterprise 许可证中的 NVIDIA NIM 从 NVIDIA 的 API 目录下载模型。在本地运行模型的能力使您的企业能够完全掌控自定义内容,并全面控制知识产权(IP)和 AI 应用程序。 NIM 微服务按模型/模型系列打包为容器镜像,并通过 NVIDIA NGC Catalog 作为 NGC 容器镜像分发。本质上,NIM 微服务是提供用于在 AI 模型上运行推理的交互式 API 的容器。 本示例介绍如何使用 LangChain 通过 NVIDIAEmbeddings 类与支持 检索增强生成NVIDIA Retrieval QA Embedding Model 进行交互。 有关通过此 API 访问聊天模型的更多信息,请参阅 ChatNVIDIA 文档。

安装软件包

pip install -qU langchain-nvidia-ai-endpoints

访问 NVIDIA API Catalog

要访问 NVIDIA API Catalog,请执行以下步骤:
  1. NVIDIA API Catalog 上创建免费账户并登录。
  2. 单击您的个人资料图标,然后单击 API KeysAPI Keys 页面将出现。
  3. 单击 Generate API KeyGenerate API Key 窗口将出现。
  4. 单击 Generate Key。您应该会看到 API Key Granted,并且您的密钥将显示出来。
  5. 复制该密钥并将其保存为 NVIDIA_API_KEY
  6. 要验证您的密钥,请使用以下代码。
import getpass
import os

if os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
    print("Valid NVIDIA_API_KEY already in environment. Delete to reset")
else:
    nvapi_key = getpass.getpass("NVAPI Key (starts with nvapi-): ")
    assert nvapi_key.startswith(
        "nvapi-"
    ), f"{nvapi_key[:5]}... is not a valid key"
    os.environ["NVIDIA_API_KEY"] = nvapi_key
现在,您可以使用您的密钥来访问 NVIDIA API Catalog 上的端点。

使用 API Catalog

在初始化嵌入模型时,您可以通过传递参数来选择模型,例如下面的 NV-Embed-QA,或者不传递任何参数以使用默认模型。
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

embedder = NVIDIAEmbeddings(model="NV-Embed-QA")
该模型是一个经过微调的 E5-large 模型,支持预期的 Embeddings 方法,包括:
  • embed_query:为查询样本生成查询嵌入。
  • embed_documents:为希望搜索的文档列表生成段落嵌入。
  • aembed_query/aembed_documents:上述方法的异步版本。

使用 NVIDIA NIM 微服务进行自托管

当您准备好部署 AI 应用程序时,可以使用 NVIDIA NIM 对模型进行自托管。有关更多信息,请参阅 NVIDIA NIM Microservices 以下代码连接到本地托管的 NIM 微服务。
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank

# 连接到在 localhost:8000 运行的聊天 NIM,指定一个模型
llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")

# 连接到在 localhost:8080 运行的嵌入 NIM
embedder = NVIDIAEmbeddings(base_url="http://localhost:8080/v1")

# 连接到在 localhost:2016 运行的重排序 NIM
ranker = NVIDIARerank(base_url="http://localhost:2016/v1")

相似度

以下是对这些数据点的相似度进行的快速测试: 查询:
  • Komchatka 的天气怎么样?
  • 意大利以哪些食物闻名?
  • 我叫什么名字?我打赌你记不住……
  • 人生的意义到底是什么?
  • 人生的意义在于享受乐趣 :D
文档:
  • Komchatka 的天气寒冷,冬季漫长且严寒。
  • 意大利以意面、披萨、冰淇淋和浓缩咖啡闻名。
  • 我无法记住个人姓名,仅提供信息。
  • 人生的目的各不相同,通常被视为个人成就的实现。
  • 享受生活中的每一刻确实是一种美好的方式。

嵌入运行时

print("\nSequential Embedding: ")
q_embeddings = [
    embedder.embed_query("What's the weather like in Komchatka?"),
    embedder.embed_query("What kinds of food is Italy known for?"),
    embedder.embed_query("What's my name? I bet you don't remember..."),
    embedder.embed_query("What's the point of life anyways?"),
    embedder.embed_query("The point of life is to have fun :D"),
]
print("Shape:", (len(q_embeddings), len(q_embeddings[0])))

文档嵌入

print("\nBatch Document Embedding: ")
d_embeddings = embedder.embed_documents(
    [
        "Komchatka's weather is cold, with long, severe winters.",
        "Italy is famous for pasta, pizza, gelato, and espresso.",
        "I can't recall personal names, only provide information.",
        "Life's purpose varies, often seen as personal fulfillment.",
        "Enjoying life's moments is indeed a wonderful approach.",
    ]
)
print("Shape:", (len(d_embeddings), len(d_embeddings[0])))
现在我们已生成嵌入向量,可以对结果进行简单的相似度检查,以查看在检索任务中哪些文档会被触发为合理的答案:
pip install -qU  matplotlib scikit-learn
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# 计算 q_embeddings 和 d_embeddings 之间的相似度矩阵
cross_similarity_matrix = cosine_similarity(
    np.array(q_embeddings),
    np.array(d_embeddings),
)

# 绘制交叉相似度矩阵
plt.figure(figsize=(8, 6))
plt.imshow(cross_similarity_matrix, cmap="Greens", interpolation="nearest")
plt.colorbar()
plt.title("Cross-Similarity Matrix")
plt.xlabel("Query Embeddings")
plt.ylabel("Document Embeddings")
plt.grid(True)
plt.show()
提醒一下,发送到我们系统的查询和文档如下: 查询:
  • Komchatka 的天气怎么样?
  • 意大利以哪些食物闻名?
  • 我叫什么名字?我打赌你记不住……
  • 人生的意义到底是什么?
  • 人生的意义在于享受乐趣 :D
文档:
  • Komchatka 的天气寒冷,冬季漫长且严寒。
  • 意大利以意面、披萨、冰淇淋和浓缩咖啡闻名。
  • 我无法记住个人姓名,仅提供信息。
  • 人生的目的各不相同,通常被视为个人成就的实现。
  • 享受生活中的每一刻确实是一种美好的方式。

截断

嵌入模型通常具有固定的上下文窗口,用于确定可嵌入的最大输入 token 数量。此限制可能是硬性限制,等于模型的最大输入 token 长度;也可能是有效限制,超出该限制后嵌入的准确性会下降。 由于模型基于 token 运行,而应用程序通常处理文本,因此应用程序很难确保其输入保持在模型的 token 限制范围内。默认情况下,如果输入过大,将抛出异常。 为此,NVIDIA 的 NIM(API Catalog 或本地)提供了一个 truncate 参数,如果输入过大,该参数会在服务器端对输入进行截断。 truncate 参数有三个选项:
  • “NONE”:默认选项。如果输入过大,将抛出异常。
  • “START”:服务器从开头(左侧)截断输入,根据需要丢弃 token。
  • “END”:服务器从末尾(右侧)截断输入,根据需要丢弃 token。
long_text = "AI is amazing, amazing is " * 100
strict_embedder = NVIDIAEmbeddings()
try:
    strict_embedder.embed_query(long_text)
except Exception as e:
    print("Error:", e)
truncating_embedder = NVIDIAEmbeddings(truncate="END")
truncating_embedder.embed_query(long_text)[:5]

RAG 检索

以下是对 LangChain Expression Language Retrieval Cookbook entry 初始示例的重新利用,但在其 playground 环境中使用 AI Foundation Models 的 Mixtral 8x7B InstructNVIDIA Retrieval QA Embedding 模型执行。食谱中的后续示例也能按预期运行,我们鼓励您尝试这些选项。 提示: 我们建议使用 Mixtral 进行内部推理(即遵循指令进行数据提取、工具选择等),并使用 Llama-Chat 生成最终的单一“总结”响应(即基于历史记录和上下文生成一条适用于该用户的简单回复)。
pip install -qU  langchain faiss-cpu tiktoken langchain-community

from operator import itemgetter

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_nvidia_ai_endpoints import ChatNVIDIA
vectorstore = FAISS.from_texts(
    ["harrison worked at kensho"],
    embedding=NVIDIAEmbeddings(model="NV-Embed-QA"),
)
retriever = vectorstore.as_retriever()

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following context:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

model = ChatNVIDIA(model="ai-mixtral-8x7b-instruct")

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

chain.invoke("where did harrison work?")
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer using information solely based on the following context:\n<Documents>\n{context}\n</Documents>"
            "\nSpeak only in the following language: {language}",
        ),
        ("user", "{question}"),
    ]
)

chain = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
        "language": itemgetter("language"),
    }
    | prompt
    | model
    | StrOutputParser()
)

chain.invoke({"question": "where did harrison work", "language": "italian"})

相关主题