feat: Add streaming only final aiter of agent (#6274)
#### Add streaming only final async iterator of agent

This callback returns an async iterator and only streams the final output of an agent. A usage sketch follows the code below.

#### Who can review?

Tag maintainers/contributors who might be interested: @agola11
commit d6cd0deaef (parent 1db266b20d), adding a new 88-line file:
```python
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.schema import LLMResult

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]


class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler that returns an async iterator.

    Only the final output of the agent will be iterated.
    """

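    # A sliding window of the most recent tokens, the same length as the
    # answer prefix, lets the handler spot the prefix anywhere in the stream.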
    def append_to_last_tokens(self, token: str) -> None:
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)

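    # Compare against the stripped variants when strip_tokens is set, so that
    # surrounding whitespace and newlines in tokens do not break the match.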
    def check_if_answer_reached(self) -> bool:
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        else:
            return self.last_tokens == self.answer_prefix_tokens

    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[List[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate AsyncFinalIteratorCallbackHandler.

        Args:
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Default is ["Final", "Answer", ":"]
            strip_tokens: Ignore white spaces and new lines when comparing
                answer_prefix_tokens to last tokens? (to determine if answer
                has been reached)
            stream_prefix: Should answer prefix itself also be streamed?
        """
        super().__init__()
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # If two calls are made in a row, this resets the state
        self.done.clear()
        self.answer_reached = False

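    # Only signal the iterator to stop once the LLM call that produced the
    # final answer has finished; intermediate agent steps keep it open.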
    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.answer_reached:
            self.done.set()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                for t in self.last_tokens:
                    self.queue.put_nowait(t)
            return

        # If yes, then put tokens from now on
        if self.answer_reached:
            self.queue.put_nowait(token)
```