mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: Add unit test when catching generator exit (#23402)
This pr adds a unit test for: https://github.com/langchain-ai/langchain/pull/22662 And narrows the scope where the exception is caught.
This commit is contained in:
parent
5e6d23f27d
commit
da7beb1c38
@ -1878,7 +1878,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output_supported = False
|
||||
else:
|
||||
final_output = chunk
|
||||
except StopIteration:
|
||||
except (StopIteration, GeneratorExit):
|
||||
pass
|
||||
for ichunk in input_for_tracing:
|
||||
if final_input_supported:
|
||||
@ -1892,8 +1892,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_input_supported = False
|
||||
else:
|
||||
final_input = ichunk
|
||||
except GeneratorExit:
|
||||
run_manager.on_chain_end(final_output, inputs=final_input)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e, inputs=final_input)
|
||||
raise
|
||||
|
@ -5706,3 +5706,39 @@ async def test_listeners_async() -> None:
|
||||
assert len(shared_state) == 2
|
||||
assert value1 in shared_state.values(), "Value not found in the dictionary."
|
||||
assert value2 in shared_state.values(), "Value not found in the dictionary."
|
||||
|
||||
|
||||
async def test_closing_iterator_doesnt_raise_error() -> None:
|
||||
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
|
||||
import time
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
on_chain_error_triggered = False
|
||||
|
||||
class MyHandler(BaseCallbackHandler):
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
nonlocal on_chain_error_triggered
|
||||
on_chain_error_triggered = True
|
||||
|
||||
llm = GenericFakeChatModel(messages=iter(["hi there"]))
|
||||
chain = llm | StrOutputParser()
|
||||
chain_ = chain.with_config({"callbacks": [MyHandler()]})
|
||||
st = chain_.stream("hello")
|
||||
next(st)
|
||||
# This is a generator so close is defined on it.
|
||||
st.close() # type: ignore
|
||||
# Wait for a bit to make sure that the callback is called.
|
||||
time.sleep(0.05)
|
||||
assert on_chain_error_triggered is False
|
||||
|
Loading…
Reference in New Issue
Block a user