mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
router runnable (#8496)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
913a156cff
commit
5e3b968078
@ -22,10 +22,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 1,
|
||||
"id": "466b65b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
@ -33,7 +42,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 2,
|
||||
"id": "3c634ef0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -583,6 +592,98 @@
|
||||
"chain2.invoke({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d094d637",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Router\n",
|
||||
"\n",
|
||||
"You can also use the router runnable to conditionally route inputs to different runnables."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "252625fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import create_tagging_chain_pydantic\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"\n",
|
||||
"class PromptToUse(BaseModel):\n",
|
||||
" \"\"\"Used to determine which prompt to use to answer the user's input.\"\"\"\n",
|
||||
" \n",
|
||||
" name: str = Field(description=\"Should be one of `math` or `english`\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "57886e84",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tagger = create_tagging_chain_pydantic(PromptToUse, ChatOpenAI(temperature=0))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a303b089",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain1 = ChatPromptTemplate.from_template(\"You are a math genius. Answer the question: {question}\") | ChatOpenAI()\n",
|
||||
"chain2 = ChatPromptTemplate.from_template(\"You are an english major. Answer the question: {question}\") | ChatOpenAI()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7aa9ea06",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema.runnable import RouterRunnable\n",
|
||||
"router = RouterRunnable({\"math\": chain1, \"english\": chain2})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "6a3d3f5d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = {\n",
|
||||
" \"key\": {\"input\": lambda x: x[\"question\"]} | tagger | (lambda x: x['text'].name),\n",
|
||||
" \"input\": {\"question\": lambda x: x[\"question\"]}\n",
|
||||
"} | router"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "8aeda930",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Thank you for the compliment! The sum of 2 + 2 is equal to 4.', additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"question\": \"whats 2 + 2\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29781123",
|
||||
|
@ -108,6 +108,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> List[Output]:
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
return [self.invoke(inputs[0], configs[0])]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
return list(executor.map(self.invoke, inputs, configs))
|
||||
|
||||
@ -759,6 +763,140 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
class RouterInput(TypedDict):
|
||||
key: str
|
||||
input: Any
|
||||
|
||||
|
||||
class RouterRunnable(
|
||||
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
||||
):
|
||||
runnables: Mapping[str, Runnable[Input, Output]]
|
||||
|
||||
def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
|
||||
super().__init__(runnables=runnables)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
Mapping[str, Any],
|
||||
],
|
||||
) -> RunnableSequence[RouterInput, Other]:
|
||||
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
Mapping[str, Any],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
||||
|
||||
def invoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return runnable.invoke(actual_input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return await runnable.ainvoke(actual_input, config)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
return list(
|
||||
executor.map(
|
||||
lambda runnable, input, config: runnable.invoke(input, config),
|
||||
runnables,
|
||||
actual_inputs,
|
||||
configs,
|
||||
)
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
return await _gather_with_concurrency(
|
||||
max_concurrency,
|
||||
*(
|
||||
runnable.ainvoke(input, config)
|
||||
for runnable, input, config in zip(runnables, actual_inputs, configs)
|
||||
),
|
||||
)
|
||||
|
||||
def stream(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
yield from runnable.stream(actual_input, config)
|
||||
|
||||
async def astream(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
async for output in runnable.astream(actual_input, config):
|
||||
yield output
|
||||
|
||||
|
||||
def _patch_config(
|
||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||
) -> RunnableConfig:
|
||||
|
File diff suppressed because one or more lines are too long
@ -23,6 +23,7 @@ from langchain.schema.document import Document
|
||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import (
|
||||
RouterRunnable,
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
@ -572,6 +573,54 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
||||
assert len(map_run.child_runs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_router_runnable(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
chain1 = ChatPromptTemplate.from_template(
|
||||
"You are a math genius. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["4"])
|
||||
chain2 = ChatPromptTemplate.from_template(
|
||||
"You are an english major. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["2"])
|
||||
router = RouterRunnable({"math": chain1, "english": chain2})
|
||||
chain: Runnable = {
|
||||
"key": lambda x: x["key"],
|
||||
"input": {"question": lambda x: x["question"]},
|
||||
} | router
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
result = chain.invoke({"key": "math", "question": "2 + 2"})
|
||||
assert result == "4"
|
||||
|
||||
result2 = chain.batch(
|
||||
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||
)
|
||||
assert result2 == ["4", "2"]
|
||||
|
||||
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
|
||||
assert result == "4"
|
||||
|
||||
result2 = await chain.abatch(
|
||||
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||
)
|
||||
assert result2 == ["4", "2"]
|
||||
|
||||
# Test invoke
|
||||
router_spy = mocker.spy(router.__class__, "invoke")
|
||||
tracer = FakeTracer()
|
||||
assert (
|
||||
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
|
||||
== "4"
|
||||
)
|
||||
assert router_spy.call_args.args[1] == {
|
||||
"key": "math",
|
||||
"input": {"question": "2 + 2"},
|
||||
}
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
|
||||
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||
|
Loading…
Reference in New Issue
Block a user