Skip to main content
langchain-nvidia-ai-endpoints 包含由 NVIDIA AI Foundation Models 驱动、托管于 NVIDIA API Catalog 的聊天模型和嵌入模型的 LangChain 集成。 NVIDIA AI Foundation 模型是由社区和 NVIDIA 共同构建的模型,针对 NVIDIA 加速基础设施进行了优化以提供最佳性能。您可以使用 API 查询 NVIDIA API Catalog 上的在线端点,从 DGX 托管的云计算环境中快速获取结果;也可以通过 NVIDIA NIM(包含在 NVIDIA AI Enterprise 许可证中)从 NVIDIA API Catalog 下载模型。在本地运行模型的能力使企业能够掌控自定义内容,并完全控制其知识产权和 AI 应用。 NIM 微服务按模型/模型系列打包为容器镜像,通过 NVIDIA NGC Catalog 以 NGC 容器镜像的形式分发。NIM 微服务的核心是提供交互式 API 的容器,用于在 AI 模型上运行推理。 本示例展示如何使用 LangChain 通过 NVIDIAEmbeddings 类与支持的 NVIDIA Retrieval QA Embedding Model 进行交互,以实现检索增强生成 有关通过此 API 访问聊天模型的更多信息,请参阅 ChatNVIDIA 文档。

安装包

pip install -qU langchain-nvidia-ai-endpoints

访问 NVIDIA API Catalog

要获得 NVIDIA API Catalog 的访问权限,请执行以下步骤:
  1. NVIDIA API Catalog 上创建免费账户并登录。
  2. 点击您的头像图标,然后点击 API Keys,进入 API Keys 页面。
  3. 点击 Generate API Key,弹出 Generate API Key 窗口。
  4. 点击 Generate Key,您将看到 API Key Granted,您的密钥随即显示。
  5. 复制并保存密钥为 NVIDIA_API_KEY
  6. 使用以下代码验证您的密钥。
import getpass
import os

if os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
    print("Valid NVIDIA_API_KEY already in environment. Delete to reset")
else:
    nvapi_key = getpass.getpass("NVAPI Key (starts with nvapi-): ")
    assert nvapi_key.startswith(
        "nvapi-"
    ), f"{nvapi_key[:5]}... is not a valid key"
    os.environ["NVIDIA_API_KEY"] = nvapi_key
现在您可以使用密钥访问 NVIDIA API Catalog 上的端点。

使用 API Catalog

初始化嵌入模型时,您可以通过传入模型名称(如下方的 NV-Embed-QA)来选择模型,或者不传入任何参数以使用默认模型。
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

embedder = NVIDIAEmbeddings(model="NV-Embed-QA")
该模型是经过微调的 E5-large 模型,支持预期的 Embeddings 方法,包括:
  • embed_query:为查询样本生成查询嵌入向量。
  • embed_documents:为您希望搜索的文档列表生成段落嵌入向量。
  • aembed_query/aembed_documents:上述方法的异步版本。

使用 NVIDIA NIM 微服务自托管

当您准备好部署 AI 应用时,可以使用 NVIDIA NIM 自托管模型。更多信息请参阅 NVIDIA NIM Microservices 以下代码连接到本地托管的 NIM 微服务。
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank

# connect to a chat NIM running at localhost:8000, specifying a model
llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")

# connect to an embedding NIM running at localhost:8080
embedder = NVIDIAEmbeddings(base_url="http://localhost:8080/v1")

# connect to a reranking NIM running at localhost:2016
ranker = NVIDIARerank(base_url="http://localhost:2016/v1")

相似性测试

以下是对这些数据点进行相似性的快速测试: 查询:
  • 堪察加的天气怎么样?
  • 意大利以哪些食物著称?
  • 我叫什么名字?我打赌你不记得了…
  • 生命的意义是什么?
  • 生命的意义就是享受乐趣 :D
文档:
  • 堪察加天气寒冷,冬季漫长而严酷。
  • 意大利以意面、披萨、冰淇淋和浓缩咖啡闻名。
  • 我无法回忆个人姓名,只能提供信息。
  • 生命的目的因人而异,通常被视为个人实现。
  • 享受生命中的每一刻确实是一种美妙的方式。

嵌入运行时

print("\nSequential Embedding: ")
q_embeddings = [
    embedder.embed_query("What's the weather like in Komchatka?"),
    embedder.embed_query("What kinds of food is Italy known for?"),
    embedder.embed_query("What's my name? I bet you don't remember..."),
    embedder.embed_query("What's the point of life anyways?"),
    embedder.embed_query("The point of life is to have fun :D"),
]
print("Shape:", (len(q_embeddings), len(q_embeddings[0])))

文档嵌入

print("\nBatch Document Embedding: ")
d_embeddings = embedder.embed_documents(
    [
        "Komchatka's weather is cold, with long, severe winters.",
        "Italy is famous for pasta, pizza, gelato, and espresso.",
        "I can't recall personal names, only provide information.",
        "Life's purpose varies, often seen as personal fulfillment.",
        "Enjoying life's moments is indeed a wonderful approach.",
    ]
)
print("Shape:", (len(d_embeddings), len(d_embeddings[0])))
生成嵌入向量后,我们可以对结果进行简单的相似性检验,以查看哪些文档在检索任务中会被触发为合理答案:
pip install -qU  matplotlib scikit-learn
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Compute the similarity matrix between q_embeddings and d_embeddings
cross_similarity_matrix = cosine_similarity(
    np.array(q_embeddings),
    np.array(d_embeddings),
)

# Plotting the cross-similarity matrix
plt.figure(figsize=(8, 6))
plt.imshow(cross_similarity_matrix, cmap="Greens", interpolation="nearest")
plt.colorbar()
plt.title("Cross-Similarity Matrix")
plt.xlabel("Query Embeddings")
plt.ylabel("Document Embeddings")
plt.grid(True)
plt.show()
提醒一下,发送到系统的查询和文档如下: 查询:
  • 堪察加的天气怎么样?
  • 意大利以哪些食物著称?
  • 我叫什么名字?我打赌你不记得了…
  • 生命的意义是什么?
  • 生命的意义就是享受乐趣 :D
文档:
  • 堪察加天气寒冷,冬季漫长而严酷。
  • 意大利以意面、披萨、冰淇淋和浓缩咖啡闻名。
  • 我无法回忆个人姓名,只能提供信息。
  • 生命的目的因人而异,通常被视为个人实现。
  • 享受生命中的每一刻确实是一种美妙的方式。

截断

嵌入模型通常有固定的上下文窗口,决定了可嵌入的最大输入 token 数。该限制可能是硬限制(等于模型的最大输入 token 长度),也可能是有效限制(超出此限制后嵌入精度会下降)。 由于模型基于 token 运作而应用通常处理文本,因此确保输入在模型 token 限制内可能比较困难。默认情况下,如果输入过大,将抛出异常。 为此,NVIDIA 的 NIM(API Catalog 或本地)提供了 truncate 参数,在输入过大时在服务端进行截断。 truncate 参数有三个选项:
  • “NONE”:默认选项。输入过大时抛出异常。
  • “START”:服务端从开头(左侧)截断输入,丢弃必要的 token。
  • “END”:服务端从末尾(右侧)截断输入,丢弃必要的 token。
long_text = "AI is amazing, amazing is " * 100
strict_embedder = NVIDIAEmbeddings()
try:
    strict_embedder.embed_query(long_text)
except Exception as e:
    print("Error:", e)
truncating_embedder = NVIDIAEmbeddings(truncate="END")
truncating_embedder.embed_query(long_text)[:5]

RAG 检索

以下是对 LangChain 表达式语言检索 Cookbook 条目 初始示例的改编,使用 AI Foundation Models 中的 Mixtral 8x7B InstructNVIDIA Retrieval QA Embedding 模型(在其 playground 环境中提供)执行。Cookbook 中的后续示例也可按预期运行,我们鼓励您探索这些选项。 提示: 我们建议将 Mixtral 用于内部推理(即数据提取的指令遵循、工具选择等),将 Llama-Chat 用于最终的”根据历史和上下文为该用户生成简洁回应”的单次总结性回答。
pip install -qU  langchain faiss-cpu tiktoken langchain-community

from operator import itemgetter

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_nvidia_ai_endpoints import ChatNVIDIA
vectorstore = FAISS.from_texts(
    ["harrison worked at kensho"],
    embedding=NVIDIAEmbeddings(model="NV-Embed-QA"),
)
retriever = vectorstore.as_retriever()

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following context:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

model = ChatNVIDIA(model="ai-mixtral-8x7b-instruct")

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

chain.invoke("where did harrison work?")
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer using information solely based on the following context:\n<Documents>\n{context}\n</Documents>"
            "\nSpeak only in the following language: {language}",
        ),
        ("user", "{question}"),
    ]
)

chain = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
        "language": itemgetter("language"),
    }
    | prompt
    | model
    | StrOutputParser()
)

chain.invoke({"question": "where did harrison work", "language": "italian"})

相关主题