"""Callback Handler streams to stdout on new llm token."""
import sys
from typing import Any, Dict, List, Optional

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]


class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
    """Callback handler for streaming in agents.

    Only works with agents using LLMs that support streaming.

    Only the final output of the agent will be streamed.
    """

    def append_to_last_tokens(self, token: str) -> None:
        """Remember the most recent tokens in a sliding window."""
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        # Keep the window at n tokens, where n = len(answer_prefix_tokens).
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)

    def check_if_answer_reached(self) -> bool:
        """Return True if the token window matches the answer prefix."""
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        else:
            return self.last_tokens == self.answer_prefix_tokens

    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[List[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate FinalStreamingStdOutCallbackHandler.

        Args:
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Default is ["Final", "Answer", ":"].
            strip_tokens: Whether to ignore white space and newlines when
                comparing answer_prefix_tokens to the last tokens, to
                determine whether the answer has been reached.
            stream_prefix: Whether the answer prefix itself should also be
                streamed.
        """
        super().__init__()
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        # Sliding windows of the n most recent tokens, raw and stripped.
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False

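    # A custom prefix can be supplied when a model tokenizes the marker
    # differently; leading-space variants such as ["Final", " Answer", ":"]
    # are the reason strip_tokens defaults to True. The split below is
    # hypothetical; verify it against the actual token stream of the model
    # in use:
    #
    #     handler = FinalStreamingStdOutCallbackHandler(
    #         answer_prefix_tokens=["The", "final", "answer", ":"]
    #     )
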
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        # Reset so each new generation scans for the answer prefix again.
        self.answer_reached = False

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Remember the last n tokens, where n = len(answer_prefix_tokens).
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list.
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                # Also emit the prefix tokens themselves before the answer.
                for t in self.last_tokens:
                    sys.stdout.write(t)
                sys.stdout.flush()
            return

        # Once the prefix has been seen, stream every subsequent token.
        if self.answer_reached:
            sys.stdout.write(token)
            sys.stdout.flush()
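
# Example usage (a minimal sketch, not part of this module). It assumes an
# OpenAI API key is configured and a langchain version that still exposes
# `OpenAI`, `load_tools`, `initialize_agent`, and `AgentType`; only the text
# after "Final Answer:" is written to stdout.
#
#     from langchain.agents import AgentType, initialize_agent, load_tools
#     from langchain.llms import OpenAI
#
#     llm = OpenAI(
#         streaming=True,
#         callbacks=[FinalStreamingStdOutCallbackHandler()],
#         temperature=0,
#     )
#     tools = load_tools(["llm-math"], llm=llm)
#     agent = initialize_agent(
#         tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
#     )
#     agent.run("What is 2 raised to the 0.235 power?")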