Skip to main content
Ray Serve 是一个用于构建在线推理 API 的可扩展模型服务库。Serve 特别适合系统组合,使您能够在 Python 代码中构建由多个链和业务逻辑组成的复杂推理服务。

本 notebook 的目标

本 notebook 展示了一个将 OpenAI 链部署到生产环境的简单示例。您可以扩展它来部署您自己的自托管模型,在那里您可以轻松定义运行模型所需的硬件资源(GPU 和 CPU)以高效地在生产中运行。在 Ray Serve 文档 中阅读更多关于可用选项(包括自动扩展)的信息。

设置 Ray Serve

使用 pip install ray[serve] 安装 ray。

通用框架

部署服务的通用框架如下:
# 0: 从 starlette 导入 ray serve 和 request
from ray import serve
from starlette.requests import Request


# 1: 定义 Ray Serve 部署。
@serve.deployment
class LLMServe:
    def __init__(self) -> None:
        # 所有初始化代码放在这里
        pass

    async def __call__(self, request: Request) -> str:
        # 您可以在这里解析请求
        # 并返回响应
        return "Hello World"


# 2: 将模型绑定到部署
deployment = LLMServe.bind()

# 3: 运行部署
serve.api.run(deployment)
# 关闭部署
serve.api.shutdown()

使用自定义提示部署 OpenAI 链的示例

此处 获取 OpenAI API 密钥。运行以下代码时,系统会提示您提供 API 密钥。
from langchain_classic.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
from getpass import getpass

OPENAI_API_KEY = getpass()
@serve.deployment
class DeployLLM:
    def __init__(self):
        # 在这里初始化 LLM、模板和链
        llm = OpenAI(openai_api_key=OPENAI_API_KEY)
        template = "Question: {question}\n\nAnswer: Let's think step by step."
        prompt = PromptTemplate.from_template(template)
        self.chain = LLMChain(llm=llm, prompt=prompt)

    def _run_chain(self, text: str):
        return self.chain(text)

    async def __call__(self, request: Request):
        # 1. 解析请求
        text = request.query_params["text"]
        # 2. 运行链
        resp = self._run_chain(text)
        # 3. 返回响应
        return resp["text"]
现在我们可以绑定部署。
# 将模型绑定到部署
deployment = DeployLLM.bind()
当我们想运行部署时,可以分配端口号和主机。
# 示例端口号
PORT_NUMBER = 8282
# 运行部署
serve.api.run(deployment, port=PORT_NUMBER)
现在服务部署在 localhost:8282 端口,我们可以发送 post 请求来获取结果。
import requests

text = "What NFL team won the Super Bowl in the year Justin Beiber was born?"
response = requests.post(f"http://localhost:{PORT_NUMBER}/?text={text}")
print(response.content.decode())