diff --git a/docs/modules/agents/streaming_stdout_final_only.ipynb b/docs/modules/agents/streaming_stdout_final_only.ipynb
index 1746c8e1..c96b03e7 100644
--- a/docs/modules/agents/streaming_stdout_final_only.ipynb
+++ b/docs/modules/agents/streaming_stdout_final_only.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "23234b50-e6c6-4c87-9f97-259c15f36894",
    "metadata": {
@@ -11,6 +12,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "29dd6333-307c-43df-b848-65001c01733b",
    "metadata": {},
@@ -36,6 +38,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "19a813f7",
    "metadata": {},
@@ -84,6 +87,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "53a743b8",
    "metadata": {},
@@ -92,11 +96,12 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "23602c62",
    "metadata": {},
    "source": [
-    "By default, we assume that the token sequence ``\"\\nFinal\", \" Answer\", \":\"`` indicates that the agent has reached an answers. We can, however, also pass a custom sequence to use as answer prefix."
+    "By default, we assume that the token sequence ``\"Final\", \"Answer\", \":\"`` indicates that the agent has reached an answer. We can, however, also pass a custom sequence to use as the answer prefix."
    ]
   },
   {
@@ -108,26 +113,100 @@
    "source": [
     "llm = OpenAI(\n",
     "    streaming=True,\n",
-    "    callbacks=[FinalStreamingStdOutCallbackHandler(answer_prefix_tokens=[\"\\nThe\", \" answer\", \":\"])],\n",
+    "    callbacks=[FinalStreamingStdOutCallbackHandler(answer_prefix_tokens=[\"The\", \"answer\", \":\"])],\n",
     "    temperature=0\n",
     ")"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "b1a96cc0",
    "metadata": {},
    "source": [
-    "Be aware you likely need to include whitespaces and new line characters in your token. "
+    "For convenience, the callback automatically strips whitespace and newline characters when comparing tokens to `answer_prefix_tokens`. I.e., if `answer_prefix_tokens = [\"The\", \" answer\", \":\"]`, then both `[\"\\nThe\", \" answer\", \":\"]` and `[\"The\", \" answer\", \":\"]` would be recognized as the answer prefix."
    ]
   },
   {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "9278b522",
+   "metadata": {},
+   "source": [
+    "If you don't know the tokenized version of your answer prefix, you can determine it with the following code:"
+   ]
+  },
+  {
    "cell_type": "code",
    "execution_count": null,
-   "id": "9278b522",
+   "id": "2f8f0640",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from langchain.callbacks.base import BaseCallbackHandler\n",
+    "\n",
+    "class MyCallbackHandler(BaseCallbackHandler):\n",
+    "    def on_llm_new_token(self, token, **kwargs) -> None:\n",
+    "        # print every token on a new line, wrapped in '#' so the exact tokenization is visible\n",
+    "        print(f\"#{token}#\")\n",
+    "\n",
+    "llm = OpenAI(streaming=True, callbacks=[MyCallbackHandler()])\n",
+    "tools = load_tools([\"wikipedia\", \"llm-math\"], llm=llm)\n",
+    "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n",
+    "agent.run(\"It's 2023 now. How many years ago did Konrad Adenauer become Chancellor of Germany?\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "61190e58",
+   "metadata": {},
+   "source": [
+    "### Also streaming the answer prefixes"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "1255776f",
+   "metadata": {},
+   "source": [
+    "When the parameter `stream_prefix = True` is set, the answer prefix itself will also be streamed. This can be useful when the answer prefix itself is part of the answer.\n",
+    "For example, when your answer is a JSON like\n",
+    "\n",
+    "```\n",
+    "{\n",
+    "    \"action\": \"Final answer\",\n",
+    "    \"action_input\": \"Konrad Adenauer became Chancellor 74 years ago.\"\n",
+    "}\n",
+    "```\n",
+    "\n",
+    "and you want not only the action_input to be streamed, but the entire JSON."
+   ]
+  },
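+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "3c5d2a84",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch (reusing the `llm`, `tools`, and `agent` setup from the cells above), prefix streaming can be enabled like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8f2b1c07",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# stream the final answer including the answer prefix itself\n",
+    "llm = OpenAI(\n",
+    "    streaming=True,\n",
+    "    callbacks=[FinalStreamingStdOutCallbackHandler(stream_prefix=True)],\n",
+    "    temperature=0\n",
+    ")"
+   ]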
  }
 ],
 "metadata": {
diff --git a/langchain/callbacks/streaming_stdout_final_only.py b/langchain/callbacks/streaming_stdout_final_only.py
index af992cfa..0527db35 100644
--- a/langchain/callbacks/streaming_stdout_final_only.py
+++ b/langchain/callbacks/streaming_stdout_final_only.py
@@ -4,7 +4,7 @@
 from typing import Any, Dict, List, Optional
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
-DEFAULT_ANSWER_PREFIX_TOKENS = ["\nFinal", " Answer", ":"]
+DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
 
 class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
@@ -14,12 +14,51 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
     Only the final output of the agent will be streamed.
     """
 
-    def __init__(self, answer_prefix_tokens: Optional[List[str]] = None) -> None:
+    def append_to_last_tokens(self, token: str) -> None:
+        self.last_tokens.append(token)
+        self.last_tokens_stripped.append(token.strip())
+        if len(self.last_tokens) > len(self.answer_prefix_tokens):
+            self.last_tokens.pop(0)
+            self.last_tokens_stripped.pop(0)
+
+    def check_if_answer_reached(self) -> bool:
+        if self.strip_tokens:
+            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
+        else:
+            return self.last_tokens == self.answer_prefix_tokens
+
+    def __init__(
+        self,
+        *,
+        answer_prefix_tokens: Optional[List[str]] = None,
+        strip_tokens: bool = True,
+        stream_prefix: bool = False
+    ) -> None:
+        """Instantiate FinalStreamingStdOutCallbackHandler.
+
+        Args:
+            answer_prefix_tokens: Token sequence that prefixes the answer.
+                Default is ["Final", "Answer", ":"].
+            strip_tokens: Whether to ignore white space and newlines when
+                comparing answer_prefix_tokens to the last tokens, to
+                determine whether the answer has been reached.
+            stream_prefix: Whether the answer prefix itself should be streamed.
+        """
         super().__init__()
         if answer_prefix_tokens is None:
-            answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
-        self.answer_prefix_tokens = answer_prefix_tokens
-        self.last_tokens = [""] * len(answer_prefix_tokens)
+            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
+        else:
+            self.answer_prefix_tokens = answer_prefix_tokens
+        if strip_tokens:
+            self.answer_prefix_tokens_stripped = [
+                token.strip() for token in self.answer_prefix_tokens
+            ]
+        else:
+            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
+        self.last_tokens = [""] * len(self.answer_prefix_tokens)
+        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
+        self.strip_tokens = strip_tokens
+        self.stream_prefix = stream_prefix
         self.answer_reached = False
 
     def on_llm_start(
@@ -32,15 +71,15 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
         """Run on new LLM token. Only available when streaming is enabled."""
 
         # Remember the last n tokens, where n = len(answer_prefix_tokens)
-        self.last_tokens.append(token)
-        if len(self.last_tokens) > len(self.answer_prefix_tokens):
-            self.last_tokens.pop(0)
+        self.append_to_last_tokens(token)
 
         # Check if the last n tokens match the answer_prefix_tokens list ...
-        if self.last_tokens == self.answer_prefix_tokens:
+        if self.check_if_answer_reached():
             self.answer_reached = True
-            # Do not print the last token in answer_prefix_tokens,
-            # as it's not part of the answer yet
+            if self.stream_prefix:
+                for t in self.last_tokens:
+                    sys.stdout.write(t)
+                sys.stdout.flush()
             return
 
         # ... if yes, then print tokens from now on