mirror of https://github.com/hwchase17/langchain
LangServe (#11046)
Adds LangServe package * Integrate Runnables with Fast API creating Server and a RemoteRunnable client * Support multiple runnables for a given server * Support sync/async/batch/abatch/stream/astream/astream_log on the client side (using async implementations on server) * Adds validation using annotations (relying on pydantic under the hood) -- this still has some rough edges -- e.g., open api docs do NOT generate correctly at the moment * Uses pydantic v1 namespace Known issues: type translation code doesn't handle a lot of types (e.g., TypedDicts) --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>pull/11161/head
parent
77ce9ed6f1
commit
b05bb9e136
@ -0,0 +1,82 @@
|
||||
---
|
||||
name: libs/langserve CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
paths:
|
||||
- '.github/actions/poetry_setup/action.yml'
|
||||
- '.github/tools/**'
|
||||
- '.github/workflows/_lint.yml'
|
||||
- '.github/workflows/_test.yml'
|
||||
- '.github/workflows/langserve_ci.yml'
|
||||
- 'libs/langserve/**'
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
# If another push to the same PR or branch happens while this workflow is still running,
|
||||
# cancel the earlier run in favor of the next run.
|
||||
#
|
||||
# There's no point in testing an outdated version of the code. GitHub only allows
|
||||
# a limited number of job runners to be active at the same time, so it's better to cancel
|
||||
# pointless jobs early so that more useful jobs can run sooner.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.5.1"
|
||||
WORKDIR: "libs/langserve"
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
uses:
|
||||
./.github/workflows/_lint.yml
|
||||
with:
|
||||
working-directory: libs/langserve
|
||||
secrets: inherit
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ env.WORKDIR }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
name: Python ${{ matrix.python-version }} extended tests
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: libs/langserve
|
||||
cache-key: langserve-all
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
echo "Running extended tests, installing dependencies with poetry..."
|
||||
poetry install --with test,lint --extras all
|
||||
|
||||
- name: Run tests
|
||||
run: make test
|
||||
|
||||
- name: Ensure the tests did not create any additional files
|
||||
shell: bash
|
||||
run: |
|
||||
set -eu
|
||||
|
||||
STATUS="$(git status)"
|
||||
echo "$STATUS"
|
||||
|
||||
# grep will exit non-zero if the target message isn't found,
|
||||
# and `set -e` above will cause the step to fail.
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -0,0 +1,54 @@
|
||||
.PHONY: all lint format test help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
######################
|
||||
# TESTING AND COVERAGE
|
||||
######################
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langserve --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
|
||||
lint lint_diff:
|
||||
poetry run ruff .
|
||||
poetry run black $(PYTHON_FILES) --check
|
||||
|
||||
format format_diff:
|
||||
poetry run black $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '===================='
|
||||
@echo '-- LINTING --'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'spell_check - run codespell on the project'
|
||||
@echo 'spell_fix - run codespell on the project and fix the errors'
|
||||
@echo '-- TESTS --'
|
||||
@echo 'coverage - run unit tests and generate coverage report'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
@echo '-- DOCUMENTATION tasks are from the top-level Makefile --'
|
@ -0,0 +1,124 @@
|
||||
# LangServe 🦜️🔗
|
||||
|
||||
## Overview
|
||||
|
||||
`LangServe` is a library that allows developers to host their Langchain runnables /
|
||||
call into them remotely from a runnable interface.
|
||||
|
||||
## Examples
|
||||
|
||||
For more examples, see the [examples](./examples) directory.
|
||||
|
||||
### Server
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python
|
||||
from fastapi import FastAPI
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.chat_models import ChatAnthropic, ChatOpenAI
|
||||
from langserve import add_routes
|
||||
from typing import List, Union

from langchain.schema.messages import HumanMessage, SystemMessage
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="LangChain Server",
|
||||
version="1.0",
|
||||
description="A simple api server using Langchain's Runnable interfaces",
|
||||
)
|
||||
|
||||
|
||||
# Serve Open AI and Anthropic models
|
||||
LLMInput = Union[List[Union[SystemMessage, HumanMessage, str]], str]
|
||||
|
||||
add_routes(
|
||||
app,
|
||||
ChatOpenAI(),
|
||||
path="/openai",
|
||||
input_type=LLMInput,
|
||||
config_keys=[],
|
||||
)
|
||||
add_routes(
|
||||
app,
|
||||
ChatAnthropic(),
|
||||
path="/anthropic",
|
||||
input_type=LLMInput,
|
||||
config_keys=[],
|
||||
)
|
||||
|
||||
# Serve a joke chain
|
||||
class ChainInput(TypedDict):
|
||||
"""The input to the chain."""
|
||||
|
||||
topic: str
|
||||
"""The topic of the joke."""
|
||||
|
||||
model = ChatAnthropic()
|
||||
prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}")
|
||||
add_routes(app, prompt | model, path="/chain", input_type=ChainInput)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="localhost", port=8000)
|
||||
```
|
||||
|
||||
|
||||
### Client
|
||||
|
||||
```python
|
||||
|
||||
from langchain.schema import SystemMessage, HumanMessage
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema.runnable import RunnableMap
|
||||
from langserve import RemoteRunnable
|
||||
|
||||
openai = RemoteRunnable("http://localhost:8000/openai/")
|
||||
anthropic = RemoteRunnable("http://localhost:8000/anthropic/")
|
||||
joke_chain = RemoteRunnable("http://localhost:8000/chain/")
|
||||
|
||||
joke_chain.invoke({"topic": "parrots"})
|
||||
|
||||
# or async
|
||||
await joke_chain.ainvoke({"topic": "parrots"})
|
||||
|
||||
prompt = [
|
||||
SystemMessage(content='Act like either a cat or a parrot.'),
|
||||
HumanMessage(content='Hello!')
|
||||
]
|
||||
|
||||
# Supports astream
|
||||
async for msg in anthropic.astream(prompt):
|
||||
print(msg, end="", flush=True)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", "Tell me a long story about {topic}")]
|
||||
)
|
||||
|
||||
# Can define custom chains
|
||||
chain = prompt | RunnableMap({
|
||||
"openai": openai,
|
||||
"anthropic": anthropic,
|
||||
})
|
||||
|
||||
chain.batch([{ "topic": "parrots" }, { "topic": "cats" }])
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# pip install langserve[all] -- has not been published to pypi yet
|
||||
```
|
||||
|
||||
or use `client` extra for client code, and `server` extra for server code.
|
||||
|
||||
## Features
|
||||
|
||||
- Deploy runnables with FastAPI
|
||||
- Client can use remote runnables almost as if they were local
|
||||
- Supports async
|
||||
- Supports batch
|
||||
- Supports stream
|
||||
|
||||
### Limitations
|
||||
|
||||
- Chain callbacks cannot be passed from the client to the server
|
@ -0,0 +1,4 @@
|
||||
"""LangServe package: client and server helpers for LangChain runnables."""
from .client import RemoteRunnable
from .server import add_routes

# Explicit public API of the package.
__all__ = ["RemoteRunnable", "add_routes"]
|
@ -0,0 +1,163 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Client\n",
|
||||
"\n",
|
||||
"Demo of client interacting with the simple chain server, which deploys a chain that tells jokes about a particular topic."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts.chat import (\n",
|
||||
" HumanMessagePromptTemplate,\n",
|
||||
" SystemMessagePromptTemplate,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langserve import RemoteRunnable\n",
|
||||
"\n",
|
||||
"remote_runnable = RemoteRunnable(\"http://localhost:8000/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Remote runnable has the same interface as local runnables"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = await remote_runnable.ainvoke({\"topic\": \"sports\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The client can also execute langchain code synchronously, and pass in configs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[AIMessage(content='Why did the football coach go to the bank?\\n\\nBecause he wanted to get his quarterback!', additional_kwargs={}, example=False),\n",
|
||||
" AIMessage(content='Why did the car bring a sweater to the race?\\n\\nBecause it wanted to have a \"car-digan\" finish!', additional_kwargs={}, example=False)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.schema.runnable.config import RunnableConfig\n",
|
||||
"\n",
|
||||
"remote_runnable.batch([{\"topic\": \"sports\"}, {\"topic\": \"cars\"}])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The server supports streaming (using HTTP server-side events), which can help interact with long responses in real time"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Ah, indulge me in this lighthearted endeavor, dear interlocutor! Allow me to regale you with a rather verbose jest concerning our hirsute friends of the wilderness, the bears!\n",
|
||||
"\n",
|
||||
"Once upon a time, in the vast expanse of a verdant forest, there existed a most erudite and sagacious bear, renowned for his prodigious intellect and unabated curiosity. This bear, with his inquisitive disposition, embarked on a quest to uncover the secrets of humor, for he believed that laughter possessed the power to unite and uplift the spirits of all creatures, great and small.\n",
|
||||
"\n",
|
||||
"Upon his journey, our erudite bear encountered a group of mischievous woodland creatures, who, captivated by his exalted intelligence, dared to challenge him to create a jest that would truly encompass the majestic essence of the bear. Our sagacious bear, never one to back down from a challenge, took a moment to ponder, his profound thoughts swirling amidst the verdant canopy above.\n",
|
||||
"\n",
|
||||
"After much contemplation, the bear delivered his jest, thusly: \"Pray, dear friends, envision a most estimable gathering of bears, replete with their formidable bulk and majestic presence. In this symposium of ursine brilliance, one bear, with a prodigious appetite, sauntered forth to procure his daily sustenance. Alas, upon reaching his intended destination, he encountered a dapper gentleman, clad in a most resplendent suit, hitherto unseen in the realm of the forest.\n",
|
||||
"\n",
|
||||
"The gentleman, possessing an air of sophistication, addressed the bear with an air of candor, remarking, 'Good sir, I must confess that your corporeal form inspires awe and admiration in equal measure. However, I beseech you, kindly abstain from consuming the berries that grow in this territory, for they possess a most deleterious effect upon the digestive systems of bears.'\n",
|
||||
"\n",
|
||||
"In response, the bear, known for his indomitable spirit, replied in a most eloquent manner, 'Dearest sir, I appreciate your concern and your eloquent admonition, yet I must humbly convey that the allure of these succulent berries is simply irresistible. The culinary satisfaction they bring far outweighs the potential discomfort they may inflict upon my digestive faculties. Therefore, I am compelled to disregard your sage counsel and indulge in their delectable essence.'\n",
|
||||
"\n",
|
||||
"And so, dear listener, the bear, driven by his insatiable hunger, proceeded to relish the berries with unmitigated gusto, heedless of the gentleman's cautions. After partaking in his feast, the bear, much to his chagrin, soon discovered the veracity of the gentleman's warning, as his digestive faculties embarked upon an unrestrained journey of turmoil and trepidation.\n",
|
||||
"\n",
|
||||
"In the aftermath of his ill-fated indulgence, the bear, with a countenance of utmost regret, turned to the gentleman and uttered, 'Verily, good sir, your counsel was indeed sagacious and prescient. I find myself ensnared in a maelstrom of gastrointestinal distress, beseeching the heavens for respite from this discomfort.'\n",
|
||||
"\n",
|
||||
"And thus, dear interlocutor, we find ourselves at the crux of this jest, whereupon the bear, in his most vulnerable state, beseeches the heavens for relief from his gastrointestinal plight. In this moment of levity, we are reminded that even the most erudite and sagacious among us can succumb to the allure of temptation, and the consequences that follow serve as a timeless lesson for all creatures within the realm of nature.\"\n",
|
||||
"\n",
|
||||
"Oh, the whimsy of the bear's gastronomic misadventure! May it serve as a reminder that, even amidst the grandeur of the natural world, we must exercise prudence and contemplate the ramifications of our actions."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"async for chunk in remote_runnable.astream({\"topic\": \"bears, but super verbose\"}):\n",
|
||||
" print(chunk.content, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
"""Example LangChain server exposes a chain composed of a prompt and an LLM."""
from fastapi import FastAPI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from typing_extensions import TypedDict

from langserve import add_routes


class ChainInput(TypedDict):
    """The input to the chain."""

    topic: str
    """The topic of the joke."""


app = FastAPI(
    title="LangChain Server",
    version="1.0",
    description="Spin up a simple api server using Langchain's Runnable interfaces",
)

# Compose the runnable: prompt template piped into the chat model.
joke_chain = ChatPromptTemplate.from_template("tell me a joke about {topic}") | ChatOpenAI()

# Expose the chain at the application root with a typed input schema.
add_routes(app, joke_chain, input_type=ChainInput)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=8000)
|
@ -0,0 +1,324 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LLMs\n",
|
||||
"\n",
|
||||
"Here we'll be interacting with a server that's exposing 2 LLMs via the runnable interface."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts.chat import ChatPromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langserve import RemoteRunnable\n",
|
||||
"\n",
|
||||
"openai_llm = RemoteRunnable(\"http://localhost:8000/openai/\")\n",
|
||||
"anthropic = RemoteRunnable(\"http://localhost:8000/anthropic/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create a prompt composed of a system message and a human message."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a highly educated person who loves to use big words. \"\n",
|
||||
" + \"You are also concise. Never answer in more than three sentences.\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"Tell me about your favorite novel\"),\n",
|
||||
" ]\n",
|
||||
").format_messages()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can use either LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\" My favorite novel is Moby Dick by Herman Melville. The intricate plot and rich symbolism make it a complex and rewarding read. Melville's masterful prose vividly evokes the perilous life of whalers on 19th century ships.\", additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"anthropic.invoke(prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"openai_llm.invoke(prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As with regular runnables, async invoke, batch and async batch variants are available by default"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='My favorite novel is \"Ulysses\" by James Joyce. It\\'s a complex and innovative work that explores the intricacies of human consciousness and the challenges of modernity in a highly poetic and experimental manner. The prose is richly layered and rewards careful reading.', additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await openai_llm.ainvoke(prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[AIMessage(content=\" My favorite novel is Moby Dick by Herman Melville. The epic tale of Captain Ahab's obsessive quest to kill the great white whale is a profound meditation on man's struggle against nature. Melville's poetic language immerses the reader in the mysticism of the high seas.\", additional_kwargs={}, example=False),\n",
|
||||
" AIMessage(content=\" My favorite novel is Moby Dick by Herman Melville. The intricate details of whaling, though tedious at times, serve to heighten the symbolism and tension leading to the epic battle between Captain Ahab and the elusive white whale. Melville's sublime yet economical prose immerses the reader in a turbulent seascape teeming with meaning.\", additional_kwargs={}, example=False)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"anthropic.batch([prompt, prompt])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[AIMessage(content=' Here is a concise description of my favorite novel in three sentences:\\n\\nMy favorite novel is Moby Dick by Herman Melville. It is the epic saga of the obsessed Captain Ahab pursuing the white whale that crippled him through the seas. The novel explores deep philosophical questions through rich symbols and metaphors.', additional_kwargs={}, example=False),\n",
|
||||
" AIMessage(content=\" My favorite novel is Moby Dick by Herman Melville. The epic tale of Captain Ahab's obsessive quest for the great white whale is a masterpiece of American literature. Melville's writing beautifully evokes the mystery and danger of the high seas.\", additional_kwargs={}, example=False)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await anthropic.abatch([prompt, prompt])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Streaming is available by default"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" My favorite novel is Moby-Dick by Herman Melville. The epic tale of Captain Ahab's quest to find and destroy the great white whale is a masterwork of American literature. Melville's dense, philosophical prose and digressive storytelling style make the novel a uniquely challenging and rewarding read."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in anthropic.stream(prompt):\n",
|
||||
" print(chunk.content, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" My favorite novel is The Art of Language by Maximo Quilana. It is a philosophical treatise on the beauty and complexity of human speech. The prose is elegant yet precise."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"async for chunk in anthropic.astream(prompt):\n",
|
||||
" print(chunk.content, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema.runnable import PutLocalVar, GetLocalVar"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"comedian_chain = (\n",
|
||||
" ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a comedian that sometimes tells funny jokes and other times you just state facts that are not funny. Please either tell a joke or state fact now but only output one.\",\n",
|
||||
" ),\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" | openai_llm\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"joke_classifier_chain = (\n",
|
||||
" ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"Please determine if the joke is funny. Say `funny` if it's funny and `not funny` if not funny. Then repeat the first five words of the joke for reference...\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"{joke}\"),\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" | anthropic\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"chain = (\n",
|
||||
" comedian_chain\n",
|
||||
" | PutLocalVar(\"joke\")\n",
|
||||
" | {\"joke\": GetLocalVar(\"joke\")}\n",
|
||||
" | joke_classifier_chain\n",
|
||||
" | PutLocalVar(\"classification\")\n",
|
||||
" | {\"joke\": GetLocalVar(\"joke\"), \"classification\": GetLocalVar(\"classification\")}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'joke': AIMessage(content=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", additional_kwargs={}, example=False),\n",
|
||||
" 'classification': AIMessage(content=\" not funny\\nWhy don't scientists trust atoms?\", additional_kwargs={}, example=False)}"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({})"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
"""Example LangChain server exposes multiple runnables (LLMs in this case)."""
from typing import List, Union

from fastapi import FastAPI
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.prompts.chat import ChatPromptValue
from langchain.schema.messages import HumanMessage, SystemMessage

from langserve import add_routes

app = FastAPI(
    title="LangChain Server",
    version="1.0",
    description="Spin up a simple api server using Langchain's Runnable interfaces",
)

# Accepted input: a list of messages/strings, a bare string, or a chat prompt value.
LLMInput = Union[List[Union[SystemMessage, HumanMessage, str]], str, ChatPromptValue]

# Register each model under its own URL prefix with an identical contract.
for route_path, llm in (("/openai", ChatOpenAI()), ("/anthropic", ChatAnthropic())):
    add_routes(
        app,
        llm,
        path=route_path,
        input_type=LLMInput,
        config_keys=[],
    )

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=8000)
|
@ -0,0 +1,190 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Client\n",
|
||||
"\n",
|
||||
"Demo of a client interacting with a remote retriever. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langserve import RemoteRunnable\n",
|
||||
"\n",
|
||||
"remote_runnable = RemoteRunnable(\"http://localhost:8000/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Remote runnable has the same interface as local runnables"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='dogs like sticks', metadata={}),\n",
|
||||
" Document(page_content='cats like fish', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await remote_runnable.ainvoke(\"tree\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='cats like fish', metadata={}),\n",
|
||||
" Document(page_content='dogs like sticks', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"remote_runnable.invoke(\"water\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[[Document(page_content='dogs like sticks', metadata={}),\n",
|
||||
" Document(page_content='cats like fish', metadata={})],\n",
|
||||
" [Document(page_content='cats like fish', metadata={}),\n",
|
||||
" Document(page_content='dogs like sticks', metadata={})]]"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await remote_runnable.abatch([\"wolf\", \"tiger\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[[Document(page_content='dogs like sticks', metadata={}),\n",
|
||||
" Document(page_content='cats like fish', metadata={})],\n",
|
||||
" [Document(page_content='cats like fish', metadata={}),\n",
|
||||
" Document(page_content='dogs like sticks', metadata={})]]"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"remote_runnable.batch([\"wood\", \"feline\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Document(page_content='dogs like sticks', metadata={}), Document(page_content='cats like fish', metadata={})]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"async for chunk in remote_runnable.astream(\"ball\"):\n",
|
||||
" print(chunk)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Document(page_content='dogs like sticks', metadata={}), Document(page_content='cats like fish', metadata={})]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in remote_runnable.stream(\"ball\"):\n",
|
||||
" print(chunk)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env python
"""Example LangChain server exposes a retrieval."""
from fastapi import FastAPI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from langserve import add_routes


# Tiny in-memory vector store for demonstration purposes.
# NOTE: requires OPENAI_API_KEY to be set for the embeddings call.
vectorstore = FAISS.from_texts(
    ["cats like fish", "dogs like sticks"], embedding=OpenAIEmbeddings()
)
retriever = vectorstore.as_retriever()

app = FastAPI(
    title="LangChain Server",
    version="1.0",
    description="Spin up a simple api server using Langchain's Runnable interfaces",
)
# Adds routes to the app for using the retriever under:
# /invoke
# /batch
# /stream
add_routes(app, retriever, input_type=str)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=8000)
|
@ -0,0 +1,4 @@
|
||||
"""LangServe public API: a client (`RemoteRunnable`) and server helper (`add_routes`)."""
from .client import RemoteRunnable
from .server import add_routes

__all__ = ["RemoteRunnable", "add_routes"]
|
@ -0,0 +1,429 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional, Sequence, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from langchain.callbacks.tracers.log_stream import RunLogPatch
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.load import load, loads
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import Input, Output
|
||||
|
||||
|
||||
def _without_callbacks(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
"""Evict callbacks from the config since those are definitely not supported."""
|
||||
_config = config or {}
|
||||
return {k: v for k, v in _config.items() if k != "callbacks"}
|
||||
|
||||
|
||||
def _raise_for_status(response: httpx.Response) -> None:
    """Re-raise with a more informative message.

    Args:
        response: The response to check

    Raises:
        httpx.HTTPStatusError: If the response is not 2xx, appending the response
            text to the message
    """
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        message = str(e)
        # Append the response text if it exists, as it may contain more information
        # Especially useful when the user's request is malformed
        if e.response.text:
            message += f" for {e.response.text}"

        # Re-raise a fresh error carrying the enriched message but the same
        # request/response objects so callers can still inspect them.
        raise httpx.HTTPStatusError(
            message=message,
            request=e.request,
            response=e.response,
        )
|
||||
|
||||
|
||||
def _is_async() -> bool:
|
||||
"""Return True if we are in an async context."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _close_clients(sync_client: httpx.Client, async_client: httpx.AsyncClient) -> None:
    """Close the async and sync clients.

    _close_clients should not be a bound method since it is called by a weakref
    finalizer.

    Args:
        sync_client: The sync client to close
        async_client: The async client to close
    """
    sync_client.close()
    if _is_async():
        # asyncio.run() cannot be called from a thread that already has a
        # running event loop, so run the close coroutine on a fresh thread.
        # Use a ThreadPoolExecutor to run async_client_close in a separate thread
        with ThreadPoolExecutor(max_workers=1) as executor:
            # Submit the async_client_close coroutine to the thread pool
            future = executor.submit(asyncio.run, async_client.aclose())
            future.result()
    else:
        # No running loop: it is safe to spin up a temporary one here.
        asyncio.run(async_client.aclose())
|
||||
|
||||
|
||||
class RemoteRunnable(Runnable[Input, Output]):
    """A RemoteRunnable is a runnable that is executed on a remote server.

    This client implements the majority of the runnable interface.

    The following features are not supported:

    - `batch` with `return_exceptions=True` since we do not support exception
      translation from the server.
    - Callbacks via the `config` argument as serialization of callbacks is not
      supported.
    """

    def __init__(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> None:
        """Initialize the client.

        Args:
            url: The url of the server
            timeout: The timeout for requests
        """
        self.url = url
        # Separate sync and async clients so both APIs can be served without
        # creating a client per call.
        self.sync_client = httpx.Client(base_url=url, timeout=timeout)
        self.async_client = httpx.AsyncClient(base_url=url, timeout=timeout)
        # Register cleanup handler once RemoteRunnable is garbage collected
        weakref.finalize(self, _close_clients, self.sync_client, self.async_client)

    def _invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the runnable with the given input and config."""
        # Input is serialized with LangChain's serializer; callbacks are
        # stripped from the config since they cannot be sent over the wire.
        response = self.sync_client.post(
            "/invoke",
            json={
                "input": dumpd(input),
                "config": _without_callbacks(config),
                "kwargs": kwargs,
            },
        )
        _raise_for_status(response)
        # Server wraps the result as {"output": ...}; deserialize it back
        # into LangChain objects.
        return load(response.json())["output"]

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        if kwargs:
            raise NotImplementedError("kwargs not implemented yet.")
        # _call_with_config wires up local callback/tracing bookkeeping
        # around the HTTP call.
        return self._call_with_config(self._invoke, input, config=config)

    async def _ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        # Async mirror of _invoke.
        response = await self.async_client.post(
            "/invoke",
            json={
                "input": dumpd(input),
                "config": _without_callbacks(config),
                "kwargs": kwargs,
            },
        )
        _raise_for_status(response)
        return load(response.json())["output"]

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        if kwargs:
            raise NotImplementedError("kwargs not implemented yet.")
        return await self._acall_with_config(self._ainvoke, input, config)

    def _batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        if not inputs:
            return []
        if return_exceptions:
            # Server-side exceptions cannot be translated back into local
            # exception objects, so this mode is rejected outright.
            raise NotImplementedError(
                "return_exceptions is not supported for remote clients"
            )

        # Config may be a single mapping applied to all inputs or one per input.
        if isinstance(config, list):
            _config = [_without_callbacks(c) for c in config]
        else:
            _config = _without_callbacks(config)

        response = self.sync_client.post(
            "/batch",
            json={
                "inputs": dumpd(inputs),
                "config": _config,
                "kwargs": kwargs,
            },
        )
        _raise_for_status(response)
        return load(response.json())["output"]

    def batch(
        self,
        inputs: List[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> List[Output]:
        if kwargs:
            raise NotImplementedError("kwargs not implemented yet.")
        return self._batch_with_config(self._batch, inputs, config)

    async def _abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Batch invoke the runnable."""
        if not inputs:
            return []
        if return_exceptions:
            raise NotImplementedError(
                "return_exceptions is not supported for remote clients"
            )

        if isinstance(config, list):
            _config = [_without_callbacks(c) for c in config]
        else:
            _config = _without_callbacks(config)

        response = await self.async_client.post(
            "/batch",
            json={
                "inputs": dumpd(inputs),
                "config": _config,
                "kwargs": kwargs,
            },
        )
        _raise_for_status(response)
        return load(response.json())["output"]

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[RunnableConfig] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        """Batch invoke the runnable."""
        if kwargs:
            raise NotImplementedError("kwargs not implemented yet.")
        if not inputs:
            return []
        # NOTE(review): `return_exceptions` is accepted here but never
        # forwarded to _abatch, so `return_exceptions=True` is silently
        # ignored instead of raising NotImplementedError — confirm intent.
        return await self._abatch_with_config(self._abatch, inputs, config)

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Stream invoke the runnable."""
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)

        # Accumulate the streamed chunks so the final (combined) output can
        # be reported to the callback system at the end of the run.
        final_output: Optional[Output] = None

        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            dumpd(input),
            name=config.get("run_name"),
        )
        data = {
            "input": dumpd(input),
            "config": _without_callbacks(config),
            "kwargs": kwargs,
        }
        endpoint = urljoin(self.url, "stream")

        try:
            from httpx_sse import connect_sse
        except ImportError:
            raise ImportError(
                "Missing `httpx_sse` dependency to use the stream method. "
                "Install via `pip install httpx_sse`'"
            )

        try:
            # The server emits Server-Sent Events: "data" events carrying
            # serialized chunks, followed by a single "end" event.
            with connect_sse(
                self.sync_client, "POST", endpoint, json=data
            ) as event_source:
                for sse in event_source.iter_sse():
                    if sse.event == "data":
                        chunk = loads(sse.data)
                        yield chunk

                        # Combine chunks via `+=` when possible (first chunk
                        # seeds the accumulator).
                        if final_output:
                            final_output += chunk
                        else:
                            final_output = chunk
                    elif sse.event == "end":
                        break
                    else:
                        raise NotImplementedError(f"Unknown event {sse.event}")
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(final_output)

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        # Async mirror of stream(): same SSE protocol, async callback manager.
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)

        final_output: Optional[Output] = None

        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            dumpd(input),
            name=config.get("run_name"),
        )
        data = {
            "input": dumpd(input),
            "config": _without_callbacks(config),
            "kwargs": kwargs,
        }
        endpoint = urljoin(self.url, "stream")

        try:
            from httpx_sse import aconnect_sse
        except ImportError:
            raise ImportError("You must install `httpx_sse` to use the stream method.")

        try:
            async with aconnect_sse(
                self.async_client, "POST", endpoint, json=data
            ) as event_source:
                async for sse in event_source.aiter_sse():
                    if sse.event == "data":
                        chunk = loads(sse.data)
                        yield chunk

                        if final_output:
                            final_output += chunk
                        else:
                            final_output = chunk
                    elif sse.event == "end":
                        break

                    else:
                        raise NotImplementedError(f"Unknown event {sse.event}")
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(final_output)

    async def astream_log(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        *,
        include_names: Optional[Sequence[str]] = None,
        include_types: Optional[Sequence[str]] = None,
        include_tags: Optional[Sequence[str]] = None,
        exclude_names: Optional[Sequence[str]] = None,
        exclude_types: Optional[Sequence[str]] = None,
        exclude_tags: Optional[Sequence[str]] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[RunLogPatch]:
        """Stream all output from a runnable, as reported to the callback system.
        This includes all inner runs of LLMs, Retrievers, Tools, etc.

        Output is streamed as Log objects, which include a list of
        jsonpatch ops that describe how the state of the run has changed in each
        step, and the final state of the run.

        The jsonpatch ops can be applied in order to construct state.
        """

        # Create a stream handler that will emit Log objects
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)

        final_output: Optional[Output] = None

        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            dumpd(input),
            name=config.get("run_name"),
        )
        # The include/exclude filters are forwarded verbatim so the server
        # can filter which inner runs appear in the log stream.
        data = {
            "input": dumpd(input),
            "config": _without_callbacks(config),
            "kwargs": kwargs,
            "include_names": include_names,
            "include_types": include_types,
            "include_tags": include_tags,
            "exclude_names": exclude_names,
            "exclude_types": exclude_types,
            "exclude_tags": exclude_tags,
        }
        endpoint = urljoin(self.url, "stream_log")

        try:
            from httpx_sse import aconnect_sse
        except ImportError:
            raise ImportError("You must install `httpx_sse` to use the stream method.")

        try:
            async with aconnect_sse(
                self.async_client, "POST", endpoint, json=data
            ) as event_source:
                async for sse in event_source.aiter_sse():
                    if sse.event == "data":
                        # Each event carries a list of jsonpatch ops; rebuild
                        # a RunLogPatch from them.
                        data = loads(sse.data)
                        chunk = RunLogPatch(*data["ops"])
                        yield chunk

                        if final_output:
                            final_output += chunk
                        else:
                            final_output = chunk
                    elif sse.event == "end":
                        break
                    else:
                        raise NotImplementedError(f"Unknown event {sse.event}")
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(final_output)
|
@ -0,0 +1,201 @@
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load.dump import dumpd, dumps
|
||||
from langchain.load.load import load
|
||||
from langchain.schema.runnable import Runnable
|
||||
from typing_extensions import Annotated
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel
|
||||
except ImportError:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langserve.validation import (
|
||||
create_batch_request_model,
|
||||
create_invoke_request_model,
|
||||
create_runnable_config_model,
|
||||
create_stream_log_request_model,
|
||||
create_stream_request_model,
|
||||
replace_lc_object_types,
|
||||
)
|
||||
|
||||
try:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
except ImportError:
|
||||
# [server] extra not installed
|
||||
APIRouter = FastAPI = Any
|
||||
|
||||
|
||||
def _project_dict(d: Mapping, keys: Sequence[str]) -> Dict[str, Any]:
|
||||
"""Project the given keys from the given dict."""
|
||||
return {k: d[k] for k in keys if k in d}
|
||||
|
||||
|
||||
class InvokeResponse(BaseModel):
    """Response from invoking a runnable.

    A container is used to allow adding additional fields in the future.
    """

    output: Any
    """The output of the runnable.

    An object that can be serialized to JSON using LangChain serialization.
    """
|
||||
|
||||
|
||||
class BatchResponse(BaseModel):
    """Response from batch invoking runnables.

    A container is used to allow adding additional fields in the future.
    """

    output: List[Any]
    """The output of the runnable.

    An object that can be serialized to JSON using LangChain serialization.
    """
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def add_routes(
    app: Union[FastAPI, APIRouter],
    runnable: Runnable,
    *,
    path: str = "",
    input_type: Type = Any,
    config_keys: Sequence[str] = (),
) -> None:
    """Register the routes on the given FastAPI app or APIRouter.

    Registers POST endpoints under `path`: /invoke, /batch, /stream and
    /stream_log, all backed by the async methods of `runnable`.

    Args:
        app: The FastAPI app or APIRouter to which routes should be added.
        runnable: The runnable to wrap, must not be stateful.
        path: A path to prepend to all routes.
        input_type: Optional type to define a schema for the input part of the request.
            If not provided, any input that can be de-serialized with LangChain's
            serializer will be accepted.
        config_keys: list of config keys that will be accepted, by default
            no config keys are accepted.
    """
    try:
        from sse_starlette import EventSourceResponse
    except ImportError:
        raise ImportError(
            "sse_starlette must be installed to implement the stream and "
            "stream_log endpoints. "
            "Use `pip install sse_starlette` to install."
        )

    # Replace LangChain Serializable types in the annotation with pydantic
    # models that validate their serialized (dict) representation.
    input_type = replace_lc_object_types(input_type)

    namespace = path or ""

    # Pydantic model names must be globally unique for OpenAPI docs, so the
    # path is folded into each generated model's name.
    model_namespace = path.strip("/").replace("/", "_")

    config = create_runnable_config_model(model_namespace, config_keys)
    InvokeRequest = create_invoke_request_model(model_namespace, input_type, config)
    BatchRequest = create_batch_request_model(model_namespace, input_type, config)
    # Stream request is the same as invoke request, but with a different response type
    StreamRequest = create_stream_request_model(model_namespace, input_type, config)
    StreamLogRequest = create_stream_log_request_model(
        model_namespace, input_type, config
    )

    @app.post(
        f"{namespace}/invoke",
        response_model=InvokeResponse,
    )
    async def invoke(
        request: Annotated[InvokeRequest, InvokeRequest]
    ) -> InvokeResponse:
        """Invoke the runnable with the given input and config."""
        # Request is first validated using InvokeRequest which takes into account
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
        input = load(request.dict()["input"])
        # Only whitelisted config keys are forwarded to the runnable.
        config = _project_dict(request.config, config_keys)
        output = await runnable.ainvoke(input, config=config, **request.kwargs)
        return InvokeResponse(output=dumpd(output))

    @app.post(f"{namespace}/batch", response_model=BatchResponse)
    async def batch(request: Annotated[BatchRequest, BatchRequest]) -> BatchResponse:
        """Invoke the runnable with the given inputs and config."""
        # Request is first validated using InvokeRequest which takes into account
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
        inputs = load(request.dict()["inputs"])
        # Config may be shared across all inputs or supplied per input.
        if isinstance(request.config, list):
            config = [_project_dict(config, config_keys) for config in request.config]
        else:
            config = _project_dict(request.config, config_keys)
        output = await runnable.abatch(inputs, config=config, **request.kwargs)
        return BatchResponse(output=dumpd(output))

    @app.post(f"{namespace}/stream")
    async def stream(
        request: Annotated[StreamRequest, StreamRequest],
    ) -> EventSourceResponse:
        """Invoke the runnable stream the output."""
        # Request is first validated using InvokeRequest which takes into account
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
        input = load(request.dict()["input"])
        config = _project_dict(request.config, config_keys)

        async def _stream() -> AsyncIterator[dict]:
            """Stream the output of the runnable."""
            # Each chunk is emitted as an SSE "data" event; a final "end"
            # event signals completion to the client.
            async for chunk in runnable.astream(
                input,
                config=config,
                **request.kwargs,
            ):
                yield {"data": dumps(chunk), "event": "data"}
            yield {"event": "end"}

        return EventSourceResponse(_stream())

    @app.post(f"{namespace}/stream_log")
    async def stream_log(
        request: Annotated[StreamLogRequest, StreamLogRequest],
    ) -> EventSourceResponse:
        """Invoke the runnable stream the output."""
        # Request is first validated using InvokeRequest which takes into account
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
        input = load(request.dict()["input"])
        config = _project_dict(request.config, config_keys)

        async def _stream_log() -> AsyncIterator[dict]:
            """Stream the output of the runnable."""
            async for run_log_patch in runnable.astream_log(
                input,
                config=config,
                include_names=request.include_names,
                include_types=request.include_types,
                include_tags=request.include_tags,
                exclude_names=request.exclude_names,
                exclude_types=request.exclude_types,
                exclude_tags=request.exclude_tags,
                **request.kwargs,
            ):
                # Temporary adapter: ship only the jsonpatch ops; the client
                # reconstructs RunLogPatch objects from them.
                yield {
                    "data": dumps({"ops": run_log_patch.ops}),
                    "event": "data",
                }
            yield {"event": "end"}

        return EventSourceResponse(_stream_log())
|
@ -0,0 +1,200 @@
|
||||
import typing
|
||||
from typing import Any, List, Optional, Sequence, Type, Union, get_args, get_origin
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable import RunnableConfig
|
||||
|
||||
# Prefer the pydantic v1 compatibility namespace (present when pydantic v2 is
# installed) so the validation models behave identically across pydantic major
# versions; fall back to a plain pydantic v1 install.
# Fix: `validator` must be imported from BOTH branches — previously it was only
# imported in the fallback, so with pydantic v2 installed the decorator in
# _create_lc_object_validator raised NameError.
try:
    from pydantic.v1 import BaseModel, Field, create_model, validator
except ImportError:
    from pydantic import BaseModel, Field, create_model, validator

from typing_extensions import TypedDict
|
||||
|
||||
InputValidator = Union[Type[BaseModel], type]
|
||||
# The following langchain objects are considered to be safe to load.
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def create_runnable_config_model(
    ns: str, config_keys: Sequence[str]
) -> type(TypedDict):
    """Create a projection of the runnable config type.

    Args:
        ns: The namespace of the runnable config type.
        config_keys: The keys to include in the projection.
    """
    # Reject any key that is not part of RunnableConfig up front.
    annotations = RunnableConfig.__annotations__
    for key in config_keys:
        if key not in annotations:
            raise AssertionError(f"Key {key} not in RunnableConfig.")

    projection = {key: annotations[key] for key in config_keys}
    # total=False: every projected key is optional in requests.
    return TypedDict(f"{ns}RunnableConfig", projection, total=False)
|
||||
|
||||
|
||||
def create_invoke_request_model(
    namespace: str,
    input_type: InputValidator,
    config: TypedDict,
) -> Type[BaseModel]:
    """Create a pydantic model for the invoke request.

    Args:
        namespace: Prefix for the model name (keeps OpenAPI model names unique
            across multiple mounted runnables).
        input_type: Type used to validate the ``input`` field.
        config: TypedDict describing the accepted config keys.

    Returns:
        A pydantic model with ``input``, ``config`` and ``kwargs`` fields.
    """
    invoke_request_type = create_model(
        f"{namespace}InvokeRequest",
        input=(input_type, ...),  # required
        config=(config, Field(default_factory=dict)),  # optional, defaults to {}
        kwargs=(dict, Field(default_factory=dict)),  # extra runnable kwargs
    )
    # Resolve any forward references introduced by the dynamic types.
    invoke_request_type.update_forward_refs()
    return invoke_request_type
|
||||
|
||||
|
||||
def create_stream_request_model(
    namespace: str,
    input_type: InputValidator,
    config: TypedDict,
) -> Type[BaseModel]:
    """Create a pydantic model for the stream request.

    Structurally identical to the invoke request; it exists as a separate
    model so the two endpoints get distinct OpenAPI schemas.
    """
    stream_request_model = create_model(
        f"{namespace}StreamRequest",
        input=(input_type, ...),  # required
        config=(config, Field(default_factory=dict)),
        kwargs=(dict, Field(default_factory=dict)),
    )
    stream_request_model.update_forward_refs()
    return stream_request_model
|
||||
|
||||
|
||||
def create_batch_request_model(
    namespace: str,
    input_type: InputValidator,
    config: TypedDict,
) -> Type[BaseModel]:
    """Create a pydantic model for the batch request.

    ``inputs`` is a list of ``input_type``; ``config`` may be a single config
    shared by all inputs or a list with one config per input.
    """
    batch_request_type = create_model(
        f"{namespace}BatchRequest",
        inputs=(List[input_type], ...),  # required
        config=(Union[config, List[config]], Field(default_factory=dict)),
        kwargs=(dict, Field(default_factory=dict)),
    )
    batch_request_type.update_forward_refs()
    return batch_request_type
|
||||
|
||||
|
||||
def create_stream_log_request_model(
    namespace: str,
    input_type: InputValidator,
    config: TypedDict,
) -> Type[BaseModel]:
    """Create a pydantic model for the stream_log request.

    Extends the invoke shape with the include/exclude filters accepted by
    ``Runnable.astream_log`` (all optional, default None = no filtering).
    """
    stream_log_request = create_model(
        f"{namespace}StreamLogRequest",
        input=(input_type, ...),  # required
        config=(config, Field(default_factory=dict)),
        include_names=(Optional[Sequence[str]], None),
        include_types=(Optional[Sequence[str]], None),
        include_tags=(Optional[Sequence[str]], None),
        exclude_names=(Optional[Sequence[str]], None),
        exclude_types=(Optional[Sequence[str]], None),
        exclude_tags=(Optional[Sequence[str]], None),
        kwargs=(dict, Field(default_factory=dict)),
    )
    stream_log_request.update_forward_refs()
    return stream_log_request
|
||||
|
||||
|
||||
# Cache of generated LCObject validator models, keyed by the object's
# ("pydantic", *namespace, class_name) id tuple, so each type is built once.
_TYPE_REGISTRY = {}
# Model names already handed out; used to keep OpenAPI model names unique.
_SEEN_NAMES = set()
|
||||
|
||||
|
||||
def _create_lc_object_validator(expected_id: Sequence[str]) -> Type[BaseModel]:
    """Create a validator for lc objects.

    An LCObject is used to validate LangChain objects in dict representation.

    The model is associated with a validator that checks that the id of the LCObject
    matches the expected id. This is used to ensure that the LCObject is of the
    correct type.

    For OpenAPI docs to work, each unique LCObject must have a unique name.
    The models are added to the registry to avoid creating duplicate models.

    Args:
        expected_id: The expected id of the LCObject
            (namespace parts followed by the class name).

    Returns:
        A pydantic model that can be used to validate LCObjects.
    """
    expected_id = tuple(expected_id)
    # Registry key is prefixed with "pydantic" to namespace these generated
    # models apart from anything else keyed by an lc id.
    model_id = tuple(["pydantic"]) + expected_id
    if model_id in _TYPE_REGISTRY:
        return _TYPE_REGISTRY[model_id]

    # Last element of the id is the class name.
    model_name = model_id[-1]

    if model_name in _SEEN_NAMES:
        # Use fully qualified name
        _name = ".".join(model_id)
    else:
        _name = model_name
    if _name in _SEEN_NAMES:
        raise AssertionError(f"Duplicate model name: {_name}")

    _SEEN_NAMES.add(model_name)

    class LCObject(BaseModel):
        # Mirrors LangChain's serialized object representation.
        id: List[str]
        lc: Any
        type: str
        kwargs: Any

        @validator("id", allow_reuse=True)
        def validate_id_namespace(cls, id: Sequence[str]) -> None:
            """Validate that the LCObject is one of the allowed types."""
            if tuple(id) != expected_id:
                raise ValueError(f"LCObject id {id} is not allowed: {model_id}")
            return id

    # Update the name of the model to make it unique.
    model = create_model(_name, __base__=LCObject)

    _TYPE_REGISTRY[model_id] = model
    return model
|
||||
|
||||
|
||||
def replace_lc_object_types(type_annotation: typing.Any) -> typing.Any:
    """Recursively replaces all LangChain objects with a serialized representation.

    Args:
        type_annotation: The type annotation to replace.

    Returns:
        The type annotation with all LCObject types replaced.

    Raises:
        ValueError: If a concrete generic origin other than list or tuple is
            encountered (e.g. dict/set are not supported yet).
    """
    origin = get_origin(type_annotation)
    args = get_args(type_annotation)

    if args:
        # Recurse into the generic's arguments first.
        # Fix: previously `new_args` was computed twice behind a dead
        # `isinstance(args, (list, tuple))` guard (get_args always returns a
        # tuple), leaving a potential NameError path; computed once here.
        new_args = [replace_lc_object_types(arg) for arg in args]
        if isinstance(origin, type):
            if origin is list:
                return typing.List[new_args[0]]
            elif origin is tuple:
                return typing.Tuple[tuple(new_args)]
            else:
                raise ValueError(f"Unknown origin type: {origin}")
        else:
            # Special forms such as typing.Union / typing.Optional: the origin
            # itself is subscriptable with the replaced arguments.
            return origin[tuple(new_args)]

    if isinstance(type_annotation, type):
        if issubclass(type_annotation, Serializable):
            # A LangChain serializable class: replace it with a pydantic model
            # that validates its dict (dumpd) representation.
            lc_id = type_annotation.get_lc_namespace() + [type_annotation.__name__]
            return _create_lc_object_validator(lc_id)

    return type_annotation
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,77 @@
|
||||
[tool.poetry]
|
||||
name = "langserve"
|
||||
version = "0.0.1"
|
||||
description = ""
|
||||
readme = "README.md"
|
||||
authors = ["LangChain"]
|
||||
license = "MIT"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">3.8.1,<4"
|
||||
httpx = "^0.25.0"
|
||||
langchain = { git = "https://github.com/langchain-ai/langchain", subdirectory = "libs/langchain" }
|
||||
fastapi = {version = ">0.90.1", optional = true}
|
||||
sse-starlette = {version = "^1.6.5", optional = true}
|
||||
httpx-sse = {version = "^0.3.1", optional = true}
|
||||
pydantic = "^1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyterlab = "^3.6.1"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
black = { version="^23.1.0", extras=["jupyter"] }
|
||||
ruff = "^0.0.255"
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.2.1"
|
||||
pytest-cov = "^4.0.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
pytest-mock = "^3.11.1"
|
||||
pytest-socket = "^0.6.0"
|
||||
|
||||
[tool.poetry.group.examples.dependencies]
|
||||
openai = "^0.28.0"
|
||||
uvicorn = {extras = ["standard"], version = "^0.23.2"}
|
||||
|
||||
[tool.poetry.extras]
|
||||
# Extras that are used for client
|
||||
client = ["httpx-sse"]
|
||||
# Extras that are used for server
|
||||
server = ["sse-starlette", "fastapi"]
|
||||
# All
|
||||
all = ["httpx-sse", "sse-starlette", "fastapi"]
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
# Same as Black.
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.isort]
|
||||
# TODO(Team): Temporary to make isort work with examples.
|
||||
# examples assume langserve is available as a 3rd party package
|
||||
# For simplicity we'll define it as first party for now can update later.
|
||||
known-first-party = ["langserve"]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -0,0 +1,468 @@
|
||||
"""Test the server and client together."""
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from langchain.callbacks.tracers.log_stream import RunLogPatch
|
||||
from langchain.schema.messages import HumanMessage, SystemMessage
|
||||
from langchain.schema.runnable import RunnablePassthrough
|
||||
from langchain.schema.runnable.base import RunnableLambda
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langserve.client import RemoteRunnable
|
||||
from langserve.server import add_routes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Provide a session-scoped event loop, closed once the session ends."""
    loop = asyncio.get_event_loop()
    try:
        yield loop
    finally:
        loop.close()


@pytest.fixture()
def app(event_loop: AbstractEventLoop) -> FastAPI:
    """A simple server that wraps a Runnable and exposes it as an API."""

    async def add_one_or_passthrough(x: Any) -> int:
        """Add one to int or passthrough."""
        return x + 1 if isinstance(x, int) else x

    def raise_on_call(x: int) -> int:
        """Sync function server side should never be called."""
        raise AssertionError("Should not be called")

    server = FastAPI()
    try:
        add_routes(
            server,
            RunnableLambda(func=raise_on_call, afunc=add_one_or_passthrough),
        )
        yield server
    finally:
        del server


@pytest.fixture()
def client(app: FastAPI) -> RemoteRunnable:
    """Create a FastAPI app that exposes the Runnable as an API."""
    remote = RemoteRunnable(url="http://localhost:9999")
    test_client = TestClient(app=app)
    remote.sync_client = test_client
    yield remote
    test_client.close()


@asynccontextmanager
async def get_async_client(
    server: FastAPI, path: Optional[str] = None
) -> RemoteRunnable:
    """Get an async client."""
    base_url = "http://localhost:9999"
    if path:
        base_url += path
    remote = RemoteRunnable(url=base_url)
    http_client = AsyncClient(app=server, base_url=base_url)
    remote.async_client = http_client
    try:
        yield remote
    finally:
        await http_client.aclose()


@pytest_asyncio.fixture()
async def async_client(app: FastAPI) -> RemoteRunnable:
    """Create a FastAPI app that exposes the Runnable as an API."""
    async with get_async_client(app) as remote:
        yield remote
|
||||
|
||||
|
||||
def test_server(app: FastAPI) -> None:
    """Test the server directly via HTTP requests."""
    sync_client = TestClient(app=app)

    # Test invoke
    assert sync_client.post("/invoke", json={"input": 1}).json() == {"output": 2}

    # Test batch
    assert sync_client.post("/batch", json={"inputs": [1]}).json() == {
        "output": [2],
    }

    # TODO(Team): Fix test. Issue with eventloops right now when using sync client
    ## Test stream
    # response = sync_client.post("/stream", json={"input": 1})
    # assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n"


@pytest.mark.asyncio
async def test_server_async(app: FastAPI) -> None:
    """Test the server directly via HTTP requests."""
    http_client = AsyncClient(app=app, base_url="http://localhost:9999")

    # Test invoke
    response = await http_client.post("/invoke", json={"input": 1})
    assert response.json() == {"output": 2}

    # Test batch
    response = await http_client.post("/batch", json={"inputs": [1]})
    assert response.json() == {"output": [2]}

    # Test stream
    response = await http_client.post("/stream", json={"input": 1})
    assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n"
|
||||
|
||||
|
||||
def test_invoke(client: RemoteRunnable) -> None:
    """Test sync invoke."""
    assert client.invoke(1) == 2
    assert client.invoke(HumanMessage(content="hello")) == HumanMessage(content="hello")

    # Test invocation with config
    assert client.invoke(1, config={"tags": ["test"]}) == 2


def test_batch(client: RemoteRunnable) -> None:
    """Test sync batch."""
    assert client.batch([]) == []
    assert client.batch([1, 2, 3]) == [2, 3, 4]
    expected = [HumanMessage(content="hello")]
    assert client.batch([HumanMessage(content="hello")]) == expected


@pytest.mark.asyncio
async def test_ainvoke(async_client: RemoteRunnable) -> None:
    """Test async invoke."""
    assert await async_client.ainvoke(1) == 2
    result = await async_client.ainvoke(HumanMessage(content="hello"))
    assert result == HumanMessage(content="hello")


@pytest.mark.asyncio
async def test_abatch(async_client: RemoteRunnable) -> None:
    """Test async batch."""
    assert await async_client.abatch([]) == []
    assert await async_client.abatch([1, 2, 3]) == [2, 3, 4]
    expected = [HumanMessage(content="hello")]
    assert await async_client.abatch([HumanMessage(content="hello")]) == expected
|
||||
|
||||
|
||||
# TODO(Team): Determine how to test
|
||||
# Some issue with event loops
|
||||
# def test_stream(client: RemoteRunnable) -> None:
|
||||
# """Test stream."""
|
||||
# assert list(client.stream(1)) == [2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_astream(async_client: RemoteRunnable) -> None:
    """Test async stream."""
    chunks = [chunk async for chunk in async_client.astream(1)]
    assert chunks == [2]

    data = HumanMessage(content="hello")
    chunks = [chunk async for chunk in async_client.astream(data)]
    assert chunks == [data]


@pytest.mark.asyncio
async def test_astream_log(async_client: RemoteRunnable) -> None:
    """Test async stream."""
    patches = [chunk async for chunk in async_client.astream_log(1)]
    assert len(patches) == 3

    first_op = patches[0].ops[0]
    # The run id is generated server side; read it back rather than guessing.
    run_id = first_op["value"]["id"]
    assert first_op == {
        "op": "replace",
        "path": "",
        "value": {
            "final_output": {"output": 2},
            "id": run_id,
            "logs": [],
            "streamed_output": [],
        },
    }
|
||||
|
||||
|
||||
def test_invoke_as_part_of_sequence(client: RemoteRunnable) -> None:
    """Test as part of sequence."""
    chain = client | RunnableLambda(func=lambda x: x + 1)
    # without config
    assert chain.invoke(1) == 3
    # with config
    assert chain.invoke(1, config={"tags": ["test"]}) == 3
    # without config
    assert chain.batch([1, 2]) == [3, 4]
    # with config
    assert chain.batch([1, 2], config={"tags": ["test"]}) == [3, 4]
    # TODO(Team): Determine how to test some issues with event loops for testing
    # set up
    # without config
    # assert list(chain.stream([1, 2])) == [3, 4]
    # # with config
    # assert list(chain.stream([1, 2], config={"tags": ["test"]})) == [3, 4]


@pytest.mark.asyncio
async def test_invoke_as_part_of_sequence_async(async_client: RemoteRunnable) -> None:
    """Test as part of a sequence.

    This helps to verify that config is handled properly (e.g., callbacks are not
    passed to the server, but other config is)
    """
    chain = async_client | RunnableLambda(
        func=lambda x: x + 1 if isinstance(x, int) else x
    ).with_config({"run_name": "hello"})
    # without config
    assert await chain.ainvoke(1) == 3
    # with config
    assert await chain.ainvoke(1, config={"tags": ["test"]}) == 3
    # without config
    assert await chain.abatch([1, 2]) == [3, 4]
    # with config
    assert await chain.abatch([1, 2], config={"tags": ["test"]}) == [3, 4]

    # Verify can pass many configs to batch
    configs = [{"tags": ["test"]}, {"tags": ["test2"]}]
    assert await chain.abatch([1, 2], config=configs) == [3, 4]

    # A mismatched number of configs must raise ValueError
    with pytest.raises(ValueError):
        assert await chain.abatch([1, 2], config=[configs[0]]) == [3, 4]

    configs = [{"tags": ["test"]}, {"tags": ["test2"]}]
    assert await chain.abatch([1, 2], config=configs) == [3, 4]

    configs = [
        {"tags": ["test"]},
        {"tags": ["test2"], "other": "test"},
    ]
    assert await chain.abatch([1, 2], config=configs) == [3, 4]

    # Without config
    assert [x async for x in chain.astream(1)] == [3]

    # With Config
    assert [x async for x in chain.astream(1, config={"tags": ["test"]})] == [3]

    # With config and LC input data
    streamed = [
        x
        async for x in chain.astream(
            HumanMessage(content="hello"), config={"tags": ["test"]}
        )
    ]
    assert streamed == [HumanMessage(content="hello")]

    log_patches = [x async for x in chain.astream_log(1)]
    for log_patch in log_patches:
        assert isinstance(log_patch, RunLogPatch)
    # Only check the first entry (not validating implementation here)
    first_op = log_patches[0].ops[0]
    assert first_op["op"] == "replace"
    assert first_op["path"] == ""

    # Validate with HumanMessage
    log_patches = [x async for x in chain.astream_log(HumanMessage(content="hello"))]
    for log_patch in log_patches:
        assert isinstance(log_patch, RunLogPatch)
    # Only check the first entry (not validating implementation here)
    first_op = log_patches[0].ops[0]
    assert first_op == {
        "op": "replace",
        "path": "",
        "value": {
            "final_output": None,
            "id": first_op["value"]["id"],
            "logs": [],
            "streamed_output": [],
        },
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
    """Test serving multiple runnables."""

    async def add_one(x: int) -> int:
        """Add one to simulate a valid function"""
        return x + 1

    async def mul_2(x: int) -> int:
        """Multiply by two to simulate a valid function"""
        return x * 2

    server = FastAPI()
    add_routes(server, RunnableLambda(add_one), path="/add_one")
    add_routes(
        server,
        RunnableLambda(mul_2),
        input_type=int,
        path="/mul_2",
    )

    async with get_async_client(server, path="/add_one") as add_client:
        async with get_async_client(server, path="/mul_2") as mul_client:
            assert await add_client.ainvoke(1) == 2
            assert await mul_client.ainvoke(4) == 8

            composite = add_client | mul_client
            assert await composite.ainvoke(3) == 8

            # Invoke remote add_one, local add_one, remote mul_2
            composite_2 = add_client | add_one | mul_client
            assert await composite_2.ainvoke(3) == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_validation(
    event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
    """Test client side and server side exceptions."""

    async def add_one(x: int) -> int:
        """Add one to simulate a valid function"""
        return x + 1

    server_runnable = RunnableLambda(func=add_one, afunc=add_one)
    server_runnable2 = RunnableLambda(func=add_one, afunc=add_one)

    app = FastAPI()
    # One route without config_keys, one that forwards tags/run_name.
    add_routes(
        app,
        server_runnable,
        input_type=int,
        path="/add_one",
    )
    add_routes(
        app,
        server_runnable2,
        input_type=int,
        path="/add_one_config",
        config_keys=["tags", "run_name"],
    )

    async with get_async_client(app, path="/add_one") as runnable:
        # Valid input round-trips.
        assert await runnable.ainvoke(1) == 2
        # Invalid input surfaces as an HTTP error.
        with pytest.raises(httpx.HTTPError):
            await runnable.ainvoke("hello")

        with pytest.raises(httpx.HTTPError):
            await runnable.abatch(["hello"])

    config = {"tags": ["test"]}

    invoke_spy_1 = mocker.spy(server_runnable, "ainvoke")
    # Verify config is handled correctly
    async with get_async_client(app, path="/add_one") as runnable1:
        # Route registered without config_keys drops the client config.
        assert await runnable1.ainvoke(1, config=config) == 2
        assert invoke_spy_1.call_args[1]["config"] == {}

    invoke_spy_2 = mocker.spy(server_runnable2, "ainvoke")
    async with get_async_client(app, path="/add_one_config") as runnable2:
        # Route registered with config_keys forwards the config.
        assert await runnable2.ainvoke(1, config=config) == 2
        assert invoke_spy_2.call_args[1]["config"] == config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
    """Test client side and server side exceptions."""

    app = FastAPI()
    # Test with langchain objects
    add_routes(
        app, RunnablePassthrough(), input_type=List[HumanMessage], config_keys=["tags"]
    )
    # Invoke request
    async with get_async_client(app) as passthrough_runnable:
        invalid_inputs = [
            "Hello",
            ["hello"],
            HumanMessage(content="h"),
            [SystemMessage(content="hello")],
        ]
        for bad_input in invalid_inputs:
            with pytest.raises(httpx.HTTPError):
                await passthrough_runnable.ainvoke(bad_input)

        # Valid
        result = await passthrough_runnable.ainvoke([HumanMessage(content="hello")])
        assert isinstance(result, list)
        assert isinstance(result[0], HumanMessage)

    # Batch request
    async with get_async_client(app) as passthrough_runnable:
        # invalid
        for bad_input in ("Hello", ["hello"], [[SystemMessage(content="hello")]]):
            with pytest.raises(httpx.HTTPError):
                await passthrough_runnable.abatch(bad_input)

        # valid
        result = await passthrough_runnable.abatch([[HumanMessage(content="hello")]])
        assert isinstance(result, list)
        assert isinstance(result[0], list)
        assert isinstance(result[0][0], HumanMessage)
|
||||
|
||||
|
||||
def _assert_clients_close_on_del() -> None:
    """Deleting a RemoteRunnable must close both of its underlying HTTP clients."""
    runnable = RemoteRunnable(url="/dev/null", timeout=1)
    sync_client = runnable.sync_client
    async_client = runnable.async_client
    assert async_client.is_closed is False
    assert sync_client.is_closed is False
    # Dropping the last reference should trigger cleanup of both clients.
    del runnable
    assert sync_client.is_closed is True
    assert async_client.is_closed is True


def test_client_close() -> None:
    """Test that the clients are closed automatically when the runnable is deleted."""
    # Fix: the two close tests shared a byte-identical duplicated body and an
    # incomplete docstring; the shared checks now live in one private helper.
    _assert_clients_close_on_del()


@pytest.mark.asyncio
async def test_async_client_close() -> None:
    """Same check, but executed while an event loop is running."""
    _assert_clients_close_on_del()
|
@ -0,0 +1,236 @@
|
||||
import typing
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel, ValidationError
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from langserve.validation import (
|
||||
create_batch_request_model,
|
||||
create_invoke_request_model,
|
||||
create_runnable_config_model,
|
||||
replace_lc_object_types,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "test_case",
    [
        {"input": {"a": "qqq"}, "kwargs": {}, "valid": False},
        {"input": {"a": 2}, "kwargs": "hello", "valid": False},
        {"input": {"a": 2}, "config": "hello", "valid": False},
        {"input": {"b": "hello"}, "valid": False},
        {"input": {"a": 2, "b": "hello"}, "config": "hello", "valid": False},
        {"input": {"a": 2, "b": "hello"}, "valid": True},
        {"input": {"a": 2, "b": "hello"}, "valid": True},
        {"input": {"a": 2}, "valid": True},
    ],
)
def test_create_invoke_and_batch_models(test_case: dict) -> None:
    """Test that the invoke request model is created correctly."""

    class Input(BaseModel):
        """Test input."""

        a: int
        b: Optional[str] = None

    valid = test_case.pop("valid")
    config = create_runnable_config_model("test", ["tags"])

    invoke_model = create_invoke_request_model("namespace", Input, config)

    if valid:
        invoke_model(**test_case)
    else:
        with pytest.raises(ValidationError):
            invoke_model(**test_case)

    # Validate batch request
    # same structure as input request, but
    # 'input' is a list of inputs and is called 'inputs'
    batch_model = create_batch_request_model("namespace", Input, config)

    test_case["inputs"] = [test_case.pop("input")]
    if valid:
        batch_model(**test_case)
    else:
        with pytest.raises(ValidationError):
            batch_model(**test_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "test_case",
    [
        {"type": int, "input": 1, "valid": True},
        {"type": float, "input": "name", "valid": False},
        {"type": float, "input": [3.2], "valid": False},
        {"type": float, "input": 1.1, "valid": True},
        {"type": Optional[float], "valid": True, "input": None},
    ],
)
def test_validation(test_case) -> None:
    """Test that the invoke request model is created correctly."""
    config = create_runnable_config_model("test", [])
    request_model = create_invoke_request_model(
        "namespace", test_case.pop("type"), config
    )

    if test_case["valid"]:
        request_model(**test_case)
    else:
        with pytest.raises(ValidationError):
            request_model(**test_case)
|
||||
|
||||
|
||||
def test_replace_lc_object_types() -> None:
    """Replace lc object types in a model."""
    message_list_type = replace_lc_object_types(typing.List[HumanMessage])
    config = create_runnable_config_model("test", [])
    invoke_request = create_invoke_request_model("namespace", message_list_type, config)

    # A serialized list of HumanMessages is accepted.
    invoke_request(
        input=dumpd(
            [
                HumanMessage(content="Hello, world!"),
                HumanMessage(content="Hello, world 2!"),
            ]
        )
    )

    # A serialized AIMessage is rejected, alone or mixed into the list.
    with pytest.raises(ValidationError):
        invoke_request(input=[dumpd(AIMessage(content="Hello, world!"))])

    with pytest.raises(ValidationError):
        invoke_request(
            input=dumpd(
                [
                    AIMessage(content="Hello, world!"),
                    HumanMessage(content="Hello, world!"),
                ]
            ),
        )


def test_batch_request_with_lc_serialization() -> None:
    """Test batch request with LC serialization."""
    input_type = replace_lc_object_types(typing.List[HumanMessage])
    config = create_runnable_config_model("test", [])
    batch_request = create_batch_request_model("namespace", input_type, config)

    # Wrong message type, missing nesting, and wrong nesting all fail.
    with pytest.raises(ValidationError):
        batch_request(inputs=dumpd([[SystemMessage(content="Hello, world!")]]))

    with pytest.raises(ValidationError):
        batch_request(inputs=dumpd(HumanMessage(content="Hello, world!")))

    with pytest.raises(ValidationError):
        batch_request(inputs=dumpd([HumanMessage(content="Hello, world!")]))

    # A correctly nested list-of-lists of HumanMessage validates.
    batch_request(inputs=dumpd([[HumanMessage(content="Hello, world!")]]))
|
||||
|
||||
|
||||
class PlaceHolderTypedDict(TypedDict):
    # NOTE(review): appears unused in this file — presumably a placeholder for a
    # future TypedDict translation test; confirm before removing.
    x: int
    z: HumanMessage


@pytest.mark.parametrize(
    "type_,input,is_valid",
    [
        (None, None, True),
        (str, "hello", True),
        (str, 123.0, True),
        (float, "qwe", False),
        (int, 1, True),
        (int, "qwe", False),
        (typing.Union[str, int], "hello", True),
        (typing.Union[str, int], 3, True),
        (typing.List[str], ["a", "b"], True),
        (typing.List[str], ["a", None], False),
        (typing.List[HumanMessage], [HumanMessage(content="hello, world!")], True),
        (typing.List[HumanMessage], [SystemMessage(content="hello, world!")], False),
        (
            typing.List[typing.Union[HumanMessage, SystemMessage]],
            [HumanMessage(content="he"), SystemMessage(content="hello, world!")],
            True,
        ),
        (
            typing.List[typing.Union[HumanMessage, SystemMessage]],
            HumanMessage(content="hello"),
            False,
        ),
        (
            typing.Union[
                typing.List[typing.Union[SystemMessage, HumanMessage, str]], str
            ],
            ["hello", "world"],
            True,
        ),
    ],
)
def test_replace_lc_object_type(
    type_: typing.Any, input: typing.Any, is_valid: bool
) -> None:
    """Verify that code runs on different python versions."""
    translated = replace_lc_object_types(type_)

    class Model(BaseModel):
        input_: translated

    if is_valid:
        Model(input_=dumpd(input))
    else:
        with pytest.raises(ValidationError):
            Model(input_=dumpd(input))
|
Loading…
Reference in New Issue