router runnable (#8496)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Harrison Chase 2023-07-31 11:07:10 -07:00 committed by GitHub
parent 913a156cff
commit 5e3b968078
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 512 additions and 3 deletions

View File

@ -22,10 +22,19 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"id": "466b65b3",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.chat_models import ChatOpenAI"
@ -33,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"id": "3c634ef0",
"metadata": {},
"outputs": [],
@ -583,6 +592,98 @@
"chain2.invoke({})"
]
},
{
"cell_type": "markdown",
"id": "d094d637",
"metadata": {},
"source": [
"## Router\n",
"\n",
"You can also use the router runnable to conditionally route inputs to different runnables."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "252625fd",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import create_tagging_chain_pydantic\n",
"from pydantic import BaseModel, Field\n",
"\n",
"class PromptToUse(BaseModel):\n",
" \"\"\"Used to determine which prompt to use to answer the user's input.\"\"\"\n",
" \n",
" name: str = Field(description=\"Should be one of `math` or `english`\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "57886e84",
"metadata": {},
"outputs": [],
"source": [
"tagger = create_tagging_chain_pydantic(PromptToUse, ChatOpenAI(temperature=0))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a303b089",
"metadata": {},
"outputs": [],
"source": [
"chain1 = ChatPromptTemplate.from_template(\"You are a math genius. Answer the question: {question}\") | ChatOpenAI()\n",
"chain2 = ChatPromptTemplate.from_template(\"You are an english major. Answer the question: {question}\") | ChatOpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7aa9ea06",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RouterRunnable\n",
"router = RouterRunnable({\"math\": chain1, \"english\": chain2})"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6a3d3f5d",
"metadata": {},
"outputs": [],
"source": [
"chain = {\n",
" \"key\": {\"input\": lambda x: x[\"question\"]} | tagger | (lambda x: x['text'].name),\n",
" \"input\": {\"question\": lambda x: x[\"question\"]}\n",
"} | router"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8aeda930",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Thank you for the compliment! The sum of 2 + 2 is equal to 4.', additional_kwargs={}, example=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"question\": \"whats 2 + 2\"})"
]
},
{
"cell_type": "markdown",
"id": "29781123",

View File

@ -108,6 +108,10 @@ class Runnable(Generic[Input, Output], ABC):
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0])]
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(self.invoke, inputs, configs))
@ -759,6 +763,140 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item
class RouterInput(TypedDict):
    """Input schema consumed by ``RouterRunnable``.

    key: name of the runnable to dispatch to.
    input: payload forwarded unchanged to the selected runnable.
    """

    key: str
    input: Any
class RouterRunnable(
    Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
    """A runnable that dispatches each input to one of several named runnables.

    The incoming ``RouterInput`` carries a ``key`` selecting an entry from
    ``runnables`` and an ``input`` payload that is forwarded, unchanged, to
    the selected runnable.
    """

    # Routing table: key -> runnable handling that key.
    runnables: Mapping[str, Runnable[Input, Output]]

    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
        super().__init__(runnables=runnables)

    class Config:
        arbitrary_types_allowed = True

    @property
    def lc_serializable(self) -> bool:
        # Participates in langchain's load/dump serialization.
        return True

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[RouterInput, Other]:
        # ``self | other`` — this router feeds the coerced ``other``.
        return RunnableSequence(first=self, last=_coerce_to_runnable(other))

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[Other, Output]:
        # ``other | self`` — the coerced ``other`` feeds this router.
        return RunnableSequence(first=_coerce_to_runnable(other), last=self)

    def _select(self, input: RouterInput):
        """Resolve ``input['key']`` to its runnable; return (runnable, payload).

        Raises ValueError when the key has no registered runnable.
        """
        route = input["key"]
        payload = input["input"]
        if route not in self.runnables:
            raise ValueError(f"No runnable associated with key '{route}'")
        return self.runnables[route], payload

    def invoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        target, payload = self._select(input)
        return target.invoke(payload, config)

    async def ainvoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        target, payload = self._select(input)
        return await target.ainvoke(payload, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Route and invoke every input concurrently on a thread pool."""
        # Validate all keys up front so no work starts on a bad batch.
        if any(entry["key"] not in self.runnables for entry in inputs):
            raise ValueError("One or more keys do not have a corresponding runnable")
        chosen = [self.runnables[entry["key"]] for entry in inputs]
        payloads = [entry["input"] for entry in inputs]
        configs = self._get_config_list(config, len(inputs))
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            futures = [
                executor.submit(target.invoke, payload, cfg)
                for target, payload, cfg in zip(chosen, payloads, configs)
            ]
            # Results in input order; the first failure propagates.
            return [future.result() for future in futures]

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Async variant of ``batch`` using bounded-concurrency gather."""
        if any(entry["key"] not in self.runnables for entry in inputs):
            raise ValueError("One or more keys do not have a corresponding runnable")
        chosen = [self.runnables[entry["key"]] for entry in inputs]
        payloads = [entry["input"] for entry in inputs]
        configs = self._get_config_list(config, len(inputs))
        coros = [
            target.ainvoke(payload, cfg)
            for target, payload, cfg in zip(chosen, payloads, configs)
        ]
        return await _gather_with_concurrency(max_concurrency, *coros)

    def stream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        """Stream from the routed runnable."""
        target, payload = self._select(input)
        yield from target.stream(payload, config)

    async def astream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        """Async-stream from the routed runnable."""
        target, payload = self._select(input)
        async for chunk in target.astream(payload, config):
            yield chunk
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:

File diff suppressed because one or more lines are too long

View File

@ -23,6 +23,7 @@ from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
RouterRunnable,
Runnable,
RunnableConfig,
RunnableLambda,
@ -572,6 +573,54 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
assert len(map_run.child_runs) == 2
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_router_runnable(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """RouterRunnable dispatches to the chain named by the ``key`` field."""
    math_chain = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    ) | FakeListLLM(responses=["4"])
    english_chain = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    ) | FakeListLLM(responses=["2"])
    router = RouterRunnable({"math": math_chain, "english": english_chain})
    chain: Runnable = {
        "key": lambda x: x["key"],
        "input": {"question": lambda x: x["question"]},
    } | router

    assert dumps(chain, pretty=True) == snapshot

    math_q = {"key": "math", "question": "2 + 2"}
    english_q = {"key": "english", "question": "2 + 2"}

    # Sync and async entry points should agree.
    assert chain.invoke(math_q) == "4"
    assert chain.batch([math_q, english_q]) == ["4", "2"]
    assert await chain.ainvoke(math_q) == "4"
    assert await chain.abatch([math_q, english_q]) == ["4", "2"]

    # Spy on the router's invoke to verify the routed payload and tracing.
    router_spy = mocker.spy(router.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(math_q, dict(callbacks=[tracer])) == "4"
    assert router_spy.call_args.args[1] == {
        "key": "math",
        "input": {"question": "2 + 2"},
    }
    assert tracer.runs == snapshot
@freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)