mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core: fix issue#24660, slove error messages about ValueError
when use model with history (#25183)
- **Description:** This PR will slove error messages about `ValueError` when use model with history. Detail in #24660. #22933 causes that `langchain_core.runnables.history.RunnableWithMessageHistory._get_output_messages` miss type check of `output_val` if `output_val` is `False`. After running `RunnableWithMessageHistory._is_not_async`, `output` is `False`.249945a572/libs/core/langchain_core/runnables/history.py (L323-L334)
15a36dd0a2/libs/core/langchain_core/runnables/history.py (L461-L471)
~~I suggest that `_get_output_messages` return empty list when `output_val == False`.~~ - **Issue**: - #24660 - **Dependencies:**: No Change. --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
ddd7919f6a
commit
d0ad713937
@ -16,7 +16,6 @@ from typing import (
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.load.load import load
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableBranch
|
||||
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import (
|
||||
@ -320,17 +319,22 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
history_chain = RunnablePassthrough.assign(
|
||||
**{messages_key: history_chain}
|
||||
).with_config(run_name="insert_history")
|
||||
|
||||
runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history)
|
||||
runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history)
|
||||
|
||||
def _call_runnable_sync(_input: Any) -> Runnable:
|
||||
return runnable_sync
|
||||
|
||||
async def _call_runnable_async(_input: Any) -> Runnable:
|
||||
return runnable_async
|
||||
|
||||
bound: Runnable = (
|
||||
history_chain
|
||||
| RunnableBranch(
|
||||
(
|
||||
RunnableLambda(
|
||||
self._is_not_async, afunc=self._is_async
|
||||
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
|
||||
runnable.with_alisteners(on_end=self._aexit_history),
|
||||
),
|
||||
runnable.with_listeners(on_end=self._exit_history),
|
||||
)
|
||||
| RunnableLambda(
|
||||
_call_runnable_sync,
|
||||
_call_runnable_async,
|
||||
).with_config(run_name="check_sync_or_async")
|
||||
).with_config(run_name="RunnableWithMessageHistory")
|
||||
|
||||
if history_factory_config:
|
||||
@ -468,7 +472,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
elif isinstance(output_val, (list, tuple)):
|
||||
return list(output_val)
|
||||
else:
|
||||
raise ValueError()
|
||||
raise ValueError(
|
||||
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
|
||||
f"Got {output_val}."
|
||||
)
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
|
@ -1,5 +1,7 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
@ -8,10 +10,12 @@ from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.runnables.base import RunnableBinding, RunnableLambda
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
from langchain_core.tracers import Run
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
@ -724,3 +728,115 @@ def test_ignore_session_id() -> None:
|
||||
_ = with_message_history.invoke("hello")
|
||||
_ = with_message_history.invoke("hello again")
|
||||
assert len(history.messages) == 4
|
||||
|
||||
|
||||
class _RunnableLambdaWithRaiseError(RunnableLambda):
|
||||
from langchain_core.tracers.root_listeners import AsyncListener
|
||||
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[
|
||||
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||
] = None,
|
||||
on_end: Optional[
|
||||
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||
] = None,
|
||||
on_error: Optional[
|
||||
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||
] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
from langchain_core.tracers.root_listeners import RootListenersTracer
|
||||
|
||||
def create_tracer(config: RunnableConfig) -> RunnableConfig:
|
||||
tracer = RootListenersTracer(
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
tracer.raise_error = True
|
||||
return {
|
||||
"callbacks": [tracer],
|
||||
}
|
||||
|
||||
return RunnableBinding(
|
||||
bound=self,
|
||||
config_factories=[lambda config: create_tracer(config)],
|
||||
)
|
||||
|
||||
def with_alisteners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[AsyncListener] = None,
|
||||
on_end: Optional[AsyncListener] = None,
|
||||
on_error: Optional[AsyncListener] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
|
||||
|
||||
def create_tracer(config: RunnableConfig) -> RunnableConfig:
|
||||
tracer = AsyncRootListenersTracer(
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
tracer.raise_error = True
|
||||
return {
|
||||
"callbacks": [tracer],
|
||||
}
|
||||
|
||||
return RunnableBinding(
|
||||
bound=self,
|
||||
config_factories=[lambda config: create_tracer(config)],
|
||||
)
|
||||
|
||||
|
||||
def test_get_output_messages_no_value_error() -> None:
|
||||
runnable = _RunnableLambdaWithRaiseError(
|
||||
lambda messages: "you said: "
|
||||
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
|
||||
)
|
||||
store: Dict = {}
|
||||
get_session_history = _get_get_session_history(store=store)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
config: RunnableConfig = {
|
||||
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
||||
}
|
||||
may_catch_value_error = None
|
||||
try:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
except ValueError as e:
|
||||
may_catch_value_error = e
|
||||
assert may_catch_value_error is None
|
||||
|
||||
|
||||
def test_get_output_messages_with_value_error() -> None:
|
||||
illegal_bool_message = False
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message)
|
||||
store: Dict = {}
|
||||
get_session_history = _get_get_session_history(store=store)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
config: RunnableConfig = {
|
||||
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
excepted = (
|
||||
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
||||
+ (" Got {}.".format(illegal_bool_message))
|
||||
)
|
||||
assert excepted in str(excinfo.value)
|
||||
|
||||
illegal_int_message = 123
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
excepted = (
|
||||
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
|
||||
+ (" Got {}.".format(illegal_int_message))
|
||||
)
|
||||
assert excepted in str(excinfo.value)
|
||||
|
Loading…
Reference in New Issue
Block a user