core[patch]: Add unit test when catching generator exit (#23402)

This pr adds a unit test for:
https://github.com/langchain-ai/langchain/pull/22662
And narrows the scope where the exception is caught.
This commit is contained in:
Eugene Yurtsev 2024-06-27 16:36:07 -04:00 committed by GitHub
parent 5e6d23f27d
commit da7beb1c38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 3 deletions

View File

@ -1878,7 +1878,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output_supported = False
else:
final_output = chunk
except StopIteration:
except (StopIteration, GeneratorExit):
pass
for ichunk in input_for_tracing:
if final_input_supported:
@ -1892,8 +1892,6 @@ class Runnable(Generic[Input, Output], ABC):
final_input_supported = False
else:
final_input = ichunk
except GeneratorExit:
run_manager.on_chain_end(final_output, inputs=final_input)
except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input)
raise

View File

@ -5706,3 +5706,39 @@ async def test_listeners_async() -> None:
assert len(shared_state) == 2
assert value1 in shared_state.values(), "Value not found in the dictionary."
assert value2 in shared_state.values(), "Value not found in the dictionary."
async def test_closing_iterator_doesnt_raise_error() -> None:
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
import time
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.output_parsers import StrOutputParser
on_chain_error_triggered = False
class MyHandler(BaseCallbackHandler):
async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
nonlocal on_chain_error_triggered
on_chain_error_triggered = True
llm = GenericFakeChatModel(messages=iter(["hi there"]))
chain = llm | StrOutputParser()
chain_ = chain.with_config({"callbacks": [MyHandler()]})
st = chain_.stream("hello")
next(st)
# This is a generator so close is defined on it.
st.close() # type: ignore
# Wait for a bit to make sure that the callback is called.
time.sleep(0.05)
assert on_chain_error_triggered is False