mirror of https://github.com/hwchase17/langchain
Runnable single protocol (#7800)
Objects implementing Runnable: BasePromptTemplate, LLM, ChatModel, Chain, Retriever, OutputParser

- [x] Implement Runnable in base Retriever
- [x] Raise TypeError in operator methods for unsupported things
- [x] Implement dict, which calls its values in parallel and outputs a dict with the results
- [x] Merge in `+` for prompts
- [x] Confirm precedence order for operators; ideally `+` binds tighter than `|`, see https://docs.python.org/3/reference/expressions.html#operator-precedence
- [x] Add support for OpenAI functions, i.e. Chat Models must return messages
- [x] Implement a BaseMessageChunk return type for BaseChatModel: a subclass of BaseMessage whose __add__ returns a BaseMessageChunk, concatenating all str args
- [x] Update the implementation of stream/astream for LLMs and chat models to use the new optional `_stream`/`_astream` methods, whose default implementation in the base class raises NotImplementedError; use https://stackoverflow.com/a/59762827 to check whether a subclass has implemented them (a sketch of that check follows the first file's diff below)
- [x] Delete the IteratorCallbackHandler (leave the async one, because people are using it)
- [x] Make BaseLLMOutputParser implement Runnable, accepting either str or BaseMessage

A usage sketch of the resulting composition API is shown just before the diff below.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
parent 04a4d3e312
commit a612800ef0
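For review context, here is a minimal sketch of the composition API this commit introduces, using only names that appear in the unit tests further down in this diff (`FakeListChatModel`, `CommaSeparatedListOutputParser`, `SystemMessagePromptTemplate`); it mirrors `test_prompt_with_chat_model_and_parser` rather than adding anything new:

```python
from langchain.chat_models.fake import FakeListChatModel
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import SystemMessagePromptTemplate

# `+` merges prompt templates (added in this PR).
prompt = (
    SystemMessagePromptTemplate.from_template("You are a nice assistant.")
    + "{question}"
)
chat = FakeListChatModel(responses=["foo, bar"])
parser = CommaSeparatedListOutputParser()

# `|` builds a RunnableSequence; a plain dict on either side of `|` is
# coerced to a RunnableMap whose values run in parallel on the same input.
chain = prompt | chat | parser

assert chain.invoke({"question": "What is your name?"}) == ["foo", "bar"]
```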
@@ -0,0 +1,705 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Coroutine,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    TypedDict,
    TypeVar,
    Union,
    cast,
)

from pydantic import Field

from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable


async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    async with semaphore:
        return await coro


async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    if n is None:
        return await asyncio.gather(*coros)

    semaphore = asyncio.Semaphore(n)

    return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))


class RunnableConfig(TypedDict, total=False):
    tags: List[str]
    """
    Tags for this call and any sub-calls (eg. a Chain calling an LLM).
    You can use these to filter calls.
    """

    metadata: Dict[str, Any]
    """
    Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
    Keys should be strings, values should be JSON-serializable.
    """

    callbacks: Callbacks
    """
    Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
    Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
    """


Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Other = TypeVar("Other")


class Runnable(Generic[Input, Output], ABC):
    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
        ],
    ) -> RunnableSequence[Input, Other]:
        return RunnableSequence(first=self, last=_coerce_to_runnable(other))

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
        ],
    ) -> RunnableSequence[Other, Output]:
        return RunnableSequence(first=_coerce_to_runnable(other), last=self)

    @abstractmethod
    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        ...

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        return await asyncio.get_running_loop().run_in_executor(
            None, self.invoke, input, config
        )

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        configs = self._get_config_list(config, len(inputs))

        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(executor.map(self.invoke, inputs, configs))

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        configs = self._get_config_list(config, len(inputs))
        coros = map(self.ainvoke, inputs, configs)

        return await _gather_with_concurrency(max_concurrency, *coros)

    def stream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        yield self.invoke(input, config)

    async def astream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        yield await self.ainvoke(input, config)

    def _get_config_list(
        self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
    ) -> List[RunnableConfig]:
        if isinstance(config, list) and len(config) != length:
            raise ValueError(
                f"config must be a list of the same length as inputs, "
                f"but got {len(config)} configs for {length} inputs"
            )

        return (
            config
            if isinstance(config, list)
            else [config.copy() if config is not None else {} for _ in range(length)]
        )

    def _call_with_config(
        self,
        func: Callable[[Input], Output],
        input: Input,
        config: Optional[RunnableConfig],
    ) -> Output:
        from langchain.callbacks.manager import CallbackManager

        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            inheritable_tags=config.get("tags"),
            inheritable_metadata=config.get("metadata"),
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )
        try:
            output = func(input)
        except Exception as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(
                output if isinstance(output, dict) else {"output": output}
            )
            return output


class RunnableSequence(Serializable, Runnable[Input, Output]):
    first: Runnable[Input, Any]
    middle: List[Runnable[Any, Any]] = Field(default_factory=list)
    last: Runnable[Any, Output]

    @property
    def steps(self) -> List[Runnable[Any, Any]]:
        return [self.first] + self.middle + [self.last]

    @property
    def lc_serializable(self) -> bool:
        return True

    class Config:
        arbitrary_types_allowed = True

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
        ],
    ) -> RunnableSequence[Input, Other]:
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                first=self.first,
                middle=self.middle + [self.last] + [other.first] + other.middle,
                last=other.last,
            )
        else:
            return RunnableSequence(
                first=self.first,
                middle=self.middle + [self.last],
                last=_coerce_to_runnable(other),
            )

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
        ],
    ) -> RunnableSequence[Other, Output]:
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                first=other.first,
                middle=other.middle + [other.last] + [self.first] + self.middle,
                last=self.last,
            )
        else:
            return RunnableSequence(
                first=_coerce_to_runnable(other),
                middle=[self.first] + self.middle,
                last=self.last,
            )

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke all steps in sequence
        try:
            for step in self.steps:
                input = step.invoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        # finish the root run
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(
                input if isinstance(input, dict) else {"output": input}
            )
            return cast(Output, input)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        from langchain.callbacks.manager import AsyncCallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = AsyncCallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke all steps in sequence
        try:
            for step in self.steps:
                input = await step.ainvoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        # finish the root run
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(
                input if isinstance(input, dict) else {"output": input}
            )
            return cast(Output, input)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        configs = self._get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self), input if isinstance(input, dict) else {"input": input}
            )
            for cm, input in zip(callback_managers, inputs)
        ]

        # invoke
        try:
            for step in self.steps:
                inputs = step.batch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        _patch_config(config, rm.get_child())
                        for rm, config in zip(run_managers, configs)
                    ],
                    max_concurrency=max_concurrency,
                )
        # finish the root runs
        except (KeyboardInterrupt, Exception) as e:
            for rm in run_managers:
                rm.on_chain_error(e)
            raise
        else:
            for rm, input in zip(run_managers, inputs):
                rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
            return cast(List[Output], inputs)

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        from langchain.callbacks.manager import (
            AsyncCallbackManager,
            AsyncCallbackManagerForChainRun,
        )

        # setup callbacks
        configs = self._get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # start the root runs, one per input
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self), input if isinstance(input, dict) else {"input": input}
                )
                for cm, input in zip(callback_managers, inputs)
            )
        )

        # invoke .batch() on each step
        # this uses batching optimizations in Runnable subclasses, like LLM
        try:
            for step in self.steps:
                inputs = await step.abatch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        _patch_config(config, rm.get_child())
                        for rm, config in zip(run_managers, configs)
                    ],
                    max_concurrency=max_concurrency,
                )
        # finish the root runs
        except (KeyboardInterrupt, Exception) as e:
            await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
            raise
        else:
            await asyncio.gather(
                *(
                    rm.on_chain_end(
                        input if isinstance(input, dict) else {"output": input}
                    )
                    for rm, input in zip(run_managers, inputs)
                )
            )
            return cast(List[Output], inputs)

    def stream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke the first steps
        try:
            for step in [self.first] + self.middle:
                input = step.invoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise

        # stream the last step
        final: Union[Output, None] = None
        final_supported = True
        try:
            for output in self.last.stream(
                input,
                # mark the last step as a child run
                _patch_config(config, run_manager.get_child()),
            ):
                yield output
                # Accumulate output if possible, otherwise disable accumulation
                if final_supported:
                    if final is None:
                        final = output
                    else:
                        try:
                            final += output  # type: ignore[operator]
                        except TypeError:
                            final = None
                            final_supported = False
        # finish the root run
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(
                final if isinstance(final, dict) else {"output": final}
            )

    async def astream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        from langchain.callbacks.manager import AsyncCallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = AsyncCallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke the first steps
        try:
            for step in [self.first] + self.middle:
                input = await step.ainvoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise

        # stream the last step
        final: Union[Output, None] = None
        final_supported = True
        try:
            async for output in self.last.astream(
                input,
                # mark the last step as a child run
                _patch_config(config, run_manager.get_child()),
            ):
                yield output
                # Accumulate output if possible, otherwise disable accumulation
                if final_supported:
                    if final is None:
                        final = output
                    else:
                        try:
                            final += output  # type: ignore[operator]
                        except TypeError:
                            final = None
                            final_supported = False
        # finish the root run
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(
                final if isinstance(final, dict) else {"output": final}
            )


class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
    steps: Dict[str, Runnable[Input, Any]]

    @property
    def lc_serializable(self) -> bool:
        return True

    class Config:
        arbitrary_types_allowed = True

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Dict[str, Any]:
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})

        # gather results from all steps
        try:
            # copy to avoid issues from the caller mutating the steps during invoke()
            steps = self.steps.copy()
            with ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(
                        step.invoke,
                        input,
                        # mark each step as a child run
                        _patch_config(config, run_manager.get_child()),
                    )
                    for step in steps.values()
                ]
                output = {key: future.result() for key, future in zip(steps, futures)}
        # finish the root run
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(output)
            return output

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Dict[str, Any]:
        from langchain.callbacks.manager import AsyncCallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = AsyncCallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), {"input": input}
        )

        # gather results from all steps
        try:
            # copy to avoid issues from the caller mutating the steps during invoke()
            steps = self.steps.copy()
            results = await asyncio.gather(
                *(
                    step.ainvoke(
                        input,
                        # mark each step as a child run
                        _patch_config(config, run_manager.get_child()),
                    )
                    for step in steps.values()
                )
            )
            output = {key: value for key, value in zip(steps, results)}
        # finish the root run
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(output)
            return output


class RunnableLambda(Runnable[Input, Output]):
    def __init__(self, func: Callable[[Input], Output]) -> None:
        if callable(func):
            self.func = func
        else:
            raise TypeError(
                "Expected a callable type for `func`. "
                f"Instead got an unsupported type: {type(func)}"
            )

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, RunnableLambda):
            return self.func == other.func
        else:
            return False

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        return self._call_with_config(self.func, input, config)


class RunnablePassthrough(Serializable, Runnable[Input, Input]):
    @property
    def lc_serializable(self) -> bool:
        return True

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
        return self._call_with_config(lambda x: x, input, config)


def _patch_config(
    config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:
    config = config.copy()
    config["callbacks"] = callback_manager
    return config


def _coerce_to_runnable(
    thing: Union[
        Runnable[Input, Output],
        Callable[[Input], Output],
        Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
    ]
) -> Runnable[Input, Output]:
    if isinstance(thing, Runnable):
        return thing
    elif callable(thing):
        return RunnableLambda(thing)
    elif isinstance(thing, dict):
        runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
        return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
    else:
        raise TypeError(
            "Expected a Runnable, callable or dict. "
            f"Instead got an unsupported type: {type(thing)}"
        )
File diff suppressed because one or more lines are too long
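One file's diff is suppressed above; the stream/astream changes for LLMs and chat models are therefore not shown here. The commit message points to https://stackoverflow.com/a/59762827 for detecting whether a subclass actually implements the new optional `_stream` method. A minimal sketch of that technique, with an illustrative class shape (not the actual implementation from the suppressed diff):

```python
from typing import Any, Iterator


class BaseModel:
    def _stream(self, prompt: str, **kwargs: Any) -> Iterator[str]:
        raise NotImplementedError()

    def invoke(self, prompt: str, **kwargs: Any) -> str:
        raise NotImplementedError()

    def stream(self, prompt: str, **kwargs: Any) -> Iterator[str]:
        # Compare the function object looked up on the subclass with the one
        # on the base class: if they are the same object, the subclass has
        # not overridden _stream (the trick from the linked answer).
        if type(self)._stream is BaseModel._stream:
            # no streaming implementation: fall back to one final chunk
            yield self.invoke(prompt, **kwargs)
        else:
            yield from self._stream(prompt, **kwargs)
```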
@@ -0,0 +1,38 @@
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk


def test_message_chunks() -> None:
    assert AIMessageChunk(content="I am") + AIMessageChunk(
        content=" indeed."
    ) == AIMessageChunk(
        content="I am indeed."
    ), "MessageChunk + MessageChunk should be a MessageChunk"

    assert AIMessageChunk(content="I am") + HumanMessageChunk(
        content=" indeed."
    ) == AIMessageChunk(
        content="I am indeed."
    ), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side"  # noqa: E501

    assert AIMessageChunk(
        content="", additional_kwargs={"foo": "bar"}
    ) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk(
        content="", additional_kwargs={"foo": "bar", "baz": "foo"}
    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501

    assert AIMessageChunk(
        content="", additional_kwargs={"function_call": {"name": "web_search"}}
    ) + AIMessageChunk(
        content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
    ) + AIMessageChunk(
        content="",
        additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}},
    ) == AIMessageChunk(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "web_search",
                "arguments": '{\n "query": "turtles"\n}',
            }
        },
    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501
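These assertions pin down the chunk `__add__` semantics: the class of the left operand wins, `content` concatenates, and `additional_kwargs` merge recursively, with string values concatenated (e.g. streamed `function_call` arguments). The chunk implementation itself is not part of this excerpt; below is a minimal sketch of a merge helper consistent with the assertions above (the helper name is illustrative, not from the diff):

```python
from typing import Any, Dict


def merge_kwargs(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    """Merge additional_kwargs the way the assertions above expect."""
    merged = dict(left)
    for key, value in right.items():
        if key not in merged:
            merged[key] = value
        elif isinstance(merged[key], str) and isinstance(value, str):
            merged[key] += value  # e.g. streamed "arguments" fragments
        elif isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_kwargs(merged[key], value)  # e.g. function_call
        else:
            raise TypeError(f"Cannot merge additional_kwargs key {key!r}")
    return merged


assert merge_kwargs(
    {"function_call": {"name": "web_search"}},
    {"function_call": {"arguments": "{\n"}},
) == {"function_call": {"name": "web_search", "arguments": "{\n"}}
```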
@@ -0,0 +1,52 @@
from langchain.schema.messages import HumanMessageChunk
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


def test_generation_chunk() -> None:
    assert GenerationChunk(text="Hello, ") + GenerationChunk(
        text="world!"
    ) == GenerationChunk(
        text="Hello, world!"
    ), "GenerationChunk + GenerationChunk should be a GenerationChunk"

    assert GenerationChunk(text="Hello, ") + GenerationChunk(
        text="world!", generation_info={"foo": "bar"}
    ) == GenerationChunk(
        text="Hello, world!", generation_info={"foo": "bar"}
    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501

    assert GenerationChunk(text="Hello, ") + GenerationChunk(
        text="world!", generation_info={"foo": "bar"}
    ) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk(
        text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501


def test_chat_generation_chunk() -> None:
    assert ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, ")
    ) + ChatGenerationChunk(
        message=HumanMessageChunk(content="world!")
    ) == ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, world!")
    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"

    assert ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, ")
    ) + ChatGenerationChunk(
        message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
    ) == ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, world!"),
        generation_info={"foo": "bar"},
    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk with merged generation_info"  # noqa: E501

    assert ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, ")
    ) + ChatGenerationChunk(
        message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
    ) + ChatGenerationChunk(
        message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
    ) == ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, world!!"),
        generation_info={"foo": "bar", "baz": "foo"},
    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk with merged generation_info"  # noqa: E501
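These `+` semantics are what let `RunnableSequence.stream()`/`astream()` in the first file accumulate a final value via `final += output` while still yielding each chunk as it arrives. A minimal consumer-side sketch of that accumulation pattern, grounded in the assertions above:

```python
from langchain.schema.output import GenerationChunk

chunks = [
    GenerationChunk(text="Hello, "),
    GenerationChunk(text="world!", generation_info={"foo": "bar"}),
]

final = None
for chunk in chunks:
    # a real caller would also yield `chunk` to its consumer here
    final = chunk if final is None else final + chunk

assert final == GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"})
```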
@@ -0,0 +1,547 @@
from typing import Any, Dict, List, Optional
from uuid import UUID

import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion

from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.load.dump import dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import (
    ChatPromptTemplate,
    ChatPromptValue,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
    Runnable,
    RunnableConfig,
    RunnableLambda,
    RunnableMap,
    RunnablePassthrough,
    RunnableSequence,
)


class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution."""

    def __init__(self) -> None:
        """Initialize the tracer."""
        super().__init__()
        self.runs: List[Run] = []

    def _persist_run(self, run: Run) -> None:
        """Persist a run."""
        self.runs.append(run)


class FakeRunnable(Runnable[str, int]):
    def invoke(
        self,
        input: str,
        config: Optional[RunnableConfig] = None,
    ) -> int:
        return len(input)


class FakeRetriever(BaseRetriever):
    def _get_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return [Document(page_content="foo"), Document(page_content="bar")]


@pytest.fixture()
def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher:
    """Note this mock only works with `import uuid; uuid.uuid4()`,
    it does not work with `from uuid import uuid4; uuid4()`."""

    # Disable tracing to avoid fixed UUIDs causing tracing errors.
    mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"})

    side_effect = (
        UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
    )
    return mocker.patch("uuid.uuid4", side_effect=side_effect)


@pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None:
    fake = FakeRunnable()
    spy = mocker.spy(fake, "invoke")

    assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
    assert spy.call_args_list == [
        mocker.call("hello", dict(tags=["a-tag"])),
    ]
    spy.reset_mock()

    assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
    assert spy.call_args_list == [
        mocker.call("hello", dict(metadata={"key": "value"})),
    ]
    spy.reset_mock()

    assert fake.batch(
        ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
    ) == [5, 7]
    assert spy.call_args_list == [
        mocker.call("hello", dict(tags=["a-tag"])),
        mocker.call("wooorld", dict(metadata={"key": "value"})),
    ]
    spy.reset_mock()

    assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
    assert spy.call_args_list == [
        mocker.call("hello", dict(tags=["a-tag"])),
        mocker.call("wooorld", dict(tags=["a-tag"])),
    ]
    spy.reset_mock()

    assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
    assert spy.call_args_list == [
        mocker.call("hello", dict(callbacks=[])),
    ]
    spy.reset_mock()

    assert [part async for part in fake.astream("hello")] == [5]
    assert spy.call_args_list == [
        mocker.call("hello", None),
    ]
    spy.reset_mock()

    assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
        5,
        7,
    ]
    assert spy.call_args_list == [
        mocker.call("hello", dict(metadata={"key": "value"})),
        mocker.call("wooorld", dict(metadata={"key": "value"})),
    ]


@pytest.mark.asyncio
async def test_prompt() -> None:
    prompt = ChatPromptTemplate.from_messages(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
    expected = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )

    assert prompt.invoke({"question": "What is your name?"}) == expected

    assert prompt.batch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ]
    ) == [
        expected,
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]

    assert [*prompt.stream({"question": "What is your name?"})] == [expected]

    assert await prompt.ainvoke({"question": "What is your name?"}) == expected

    assert await prompt.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ]
    ) == [
        expected,
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]

    assert [
        part async for part in prompt.astream({"question": "What is your name?"})
    ] == [expected]


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo", "bar"])

    chain = prompt | chat

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == []
    assert chain.last == chat
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == AIMessage(content="foo")
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "batch")
    chat_spy = mocker.spy(chat.__class__, "batch")
    tracer = FakeTracer()
    assert chain.batch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        dict(callbacks=[tracer]),
    ) == [
        AIMessage(content="bar"),
        AIMessage(content="foo"),
    ]
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert chat_spy.call_args.args[1] == [
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your name?"),
            ]
        ),
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

    # Test stream
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "stream")
    tracer = FakeTracer()
    assert [
        *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
    ] == [AIMessage(content="bar")]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeListLLM(responses=["foo", "bar"])

    chain = prompt | llm

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == []
    assert chain.last == llm
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "ainvoke")
    tracer = FakeTracer()
    assert (
        await chain.ainvoke(
            {"question": "What is your name?"}, dict(callbacks=[tracer])
        )
        == "foo"
    )
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "abatch")
    llm_spy = mocker.spy(llm.__class__, "abatch")
    tracer = FakeTracer()
    assert await chain.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        dict(callbacks=[tracer]),
    ) == ["bar", "foo"]
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert llm_spy.call_args.args[1] == [
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your name?"),
            ]
        ),
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)

    # Test stream
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "astream")
    tracer = FakeTracer()
    assert [
        token
        async for token in chain.astream(
            {"question": "What is your name?"}, dict(callbacks=[tracer])
        )
    ] == ["bar"]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )


@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo, bar"])
    parser = CommaSeparatedListOutputParser()

    chain = prompt | chat | parser

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    parser_spy = mocker.spy(parser.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == ["foo", "bar"]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

    retriever = FakeRetriever()

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + """Context:
{documents}

Question:
{question}"""
    )

    chat = FakeListChatModel(responses=["foo, bar"])

    parser = CommaSeparatedListOutputParser()

    chain = (
        {
            "question": RunnablePassthrough[str]() | passthrough,
            "documents": passthrough | retriever,
            "just_to_test_lambda": passthrough,
        }
        | prompt
        | chat
        | parser
    )

    assert isinstance(chain, RunnableSequence)
    assert isinstance(chain.first, RunnableMap)
    assert chain.middle == [prompt, chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    parser_spy = mocker.spy(parser.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
        "foo",
        "bar",
    ]
    assert prompt_spy.call_args.args[1] == {
        "documents": [Document(page_content="foo"), Document(page_content="bar")],
        "question": "What is your name?",
        "just_to_test_lambda": "What is your name?",
    }
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(
                content="""Context:
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]

Question:
What is your name?"""
            ),
        ]
    )
    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_seq_prompt_dict(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat = FakeListChatModel(responses=["i'm a chatbot"])

    llm = FakeListLLM(responses=["i'm a textbot"])

    chain = (
        prompt
        | passthrough
        | {  # type: ignore
            "chat": chat,
            "llm": llm,
        }
    )

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [RunnableLambda(passthrough)]
    assert isinstance(chain.last, RunnableMap)
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    llm_spy = mocker.spy(llm.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == {
        "chat": AIMessage(content="i'm a chatbot"),
        "llm": "i'm a textbot",
    }
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert tracer.runs == snapshot