core: fix issue#24660, slove error messages about ValueError when use model with history (#25183)

- **Description:**
This PR will slove error messages about `ValueError` when use model with
history.
Detail in #24660.
#22933 causes that
`langchain_core.runnables.history.RunnableWithMessageHistory._get_output_messages`
miss type check of `output_val` if `output_val` is `False`. After
running `RunnableWithMessageHistory._is_not_async`, `output` is `False`.

249945a572/libs/core/langchain_core/runnables/history.py (L323-L334)

15a36dd0a2/libs/core/langchain_core/runnables/history.py (L461-L471)
~~I suggest that `_get_output_messages` return empty list when
`output_val == False`.~~

- **Issue**:
  - #24660

- **Dependencies:**: No Change.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Chengyu Yan 2024-08-14 22:26:22 +08:00 committed by GitHub
parent ddd7919f6a
commit d0ad713937
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 136 additions and 13 deletions

View File

@ -16,7 +16,6 @@ from typing import (
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
@ -320,17 +319,22 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history)
runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history)
def _call_runnable_sync(_input: Any) -> Runnable:
return runnable_sync
async def _call_runnable_async(_input: Any) -> Runnable:
return runnable_async
bound: Runnable = (
history_chain
| RunnableBranch(
(
RunnableLambda(
self._is_not_async, afunc=self._is_async
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
runnable.with_alisteners(on_end=self._aexit_history),
),
runnable.with_listeners(on_end=self._exit_history),
)
| RunnableLambda(
_call_runnable_sync,
_call_runnable_async,
).with_config(run_name="check_sync_or_async")
).with_config(run_name="RunnableWithMessageHistory")
if history_factory_config:
@ -468,7 +472,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
elif isinstance(output_val, (list, tuple)):
return list(output_val)
else:
raise ValueError()
raise ValueError(
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {output_val}."
)
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]

View File

@ -1,5 +1,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import pytest
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
@ -8,10 +10,12 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables import Runnable
from langchain_core.runnables.base import RunnableBinding, RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
from langchain_core.tracers import Run
from tests.unit_tests.pydantic_utils import _schema
@ -724,3 +728,115 @@ def test_ignore_session_id() -> None:
_ = with_message_history.invoke("hello")
_ = with_message_history.invoke("hello again")
assert len(history.messages) == 4
class _RunnableLambdaWithRaiseError(RunnableLambda):
from langchain_core.tracers.root_listeners import AsyncListener
def with_listeners(
self,
*,
on_start: Optional[
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
] = None,
on_end: Optional[
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
] = None,
on_error: Optional[
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
] = None,
) -> Runnable[Input, Output]:
from langchain_core.tracers.root_listeners import RootListenersTracer
def create_tracer(config: RunnableConfig) -> RunnableConfig:
tracer = RootListenersTracer(
config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
)
tracer.raise_error = True
return {
"callbacks": [tracer],
}
return RunnableBinding(
bound=self,
config_factories=[lambda config: create_tracer(config)],
)
def with_alisteners(
self,
*,
on_start: Optional[AsyncListener] = None,
on_end: Optional[AsyncListener] = None,
on_error: Optional[AsyncListener] = None,
) -> Runnable[Input, Output]:
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
def create_tracer(config: RunnableConfig) -> RunnableConfig:
tracer = AsyncRootListenersTracer(
config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
)
tracer.raise_error = True
return {
"callbacks": [tracer],
}
return RunnableBinding(
bound=self,
config_factories=[lambda config: create_tracer(config)],
)
def test_get_output_messages_no_value_error() -> None:
runnable = _RunnableLambdaWithRaiseError(
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: Dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
}
may_catch_value_error = None
try:
with_history.bound.invoke([HumanMessage(content="hello")], config)
except ValueError as e:
may_catch_value_error = e
assert may_catch_value_error is None
def test_get_output_messages_with_value_error() -> None:
illegal_bool_message = False
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message)
store: Dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
}
with pytest.raises(ValueError) as excinfo:
with_history.bound.invoke([HumanMessage(content="hello")], config)
excepted = (
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
+ (" Got {}.".format(illegal_bool_message))
)
assert excepted in str(excinfo.value)
illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
with pytest.raises(ValueError) as excinfo:
with_history.bound.invoke([HumanMessage(content="hello")], config)
excepted = (
"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]."
+ (" Got {}.".format(illegal_int_message))
)
assert excepted in str(excinfo.value)