Enable streaming for OpenAI LLM (#986)

* Support a callback `on_llm_new_token` that users can implement when
`OpenAI.streaming` is set to `True`
This commit is contained in:
Ankush Gola 2023-02-14 15:06:14 -08:00 committed by GitHub
parent f05f025e41
commit caa8e4742e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1311 additions and 155 deletions

View File

@ -14,7 +14,9 @@
"cell_type": "code",
"execution_count": 1,
"id": "70c4e529",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
@ -36,7 +38,9 @@
"cell_type": "code",
"execution_count": 2,
"id": "01c46e92",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
@ -56,7 +60,9 @@
"cell_type": "code",
"execution_count": 3,
"id": "433363a5",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# loaders = [....]\n",
@ -75,9 +81,11 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "a8930cf7",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@ -106,9 +114,11 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "7b4110f3",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore)"
@ -126,7 +136,9 @@
"cell_type": "code",
"execution_count": 6,
"id": "7fe3e730",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_history = []\n",
@ -136,9 +148,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"id": "bfff9cc8",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
@ -146,7 +160,7 @@
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
]
},
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -165,9 +179,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "00b4cf00",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_history = [(query, result[\"answer\"])]\n",
@ -177,9 +193,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"id": "f01828d1",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
@ -187,7 +205,7 @@
"' Justice Stephen Breyer'"
]
},
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -196,10 +214,90 @@
"result['answer']"
]
},
{
"cell_type": "markdown",
"id": "2324cdc6-98bf-4708-b8cd-02a98b1e5b67",
"metadata": {},
"source": [
"## Chat Vector DB with streaming to `stdout`\n",
"\n",
"Output from the chain will be streamed to `stdout` token by token in this example."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2efacec3-2690-4b05-8de3-a32fd2ac3911",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chains.llm import LLMChain\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"\n",
"# Construct a ChatVectorDBChain with a streaming llm for combine docs\n",
"# and a separate, non-streaming llm for question generation\n",
"llm = OpenAI(temperature=0)\n",
"streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"\n",
"question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
"doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n",
"\n",
"qa = ChatVectorDBChain(vectorstore=vectorstore, combine_docs_chain=doc_chain, question_generator=question_generator)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fd6d43f4-7428-44a4-81bc-26fe88a98762",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans."
]
}
],
"source": [
"chat_history = []\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5ab38978-f3e8-4fa7-808c-c79dec48379a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Justice Stephen Breyer"
]
}
],
"source": [
"chat_history = [(query, result[\"answer\"])]\n",
"query = \"Did he mention who she suceeded\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0f869c6",
"id": "a7ea93ff-1899-4171-9c24-85df20ae1a3d",
"metadata": {},
"outputs": [],
"source": []
@ -221,7 +319,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@ -18,7 +18,9 @@
"cell_type": "code",
"execution_count": 1,
"id": "df924055",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.llms import OpenAI"
@ -207,14 +209,6 @@
"source": [
"llm.get_num_tokens(\"what a joke\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b004ffdd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@ -8,6 +8,7 @@ They are split into two categories:
1. `Generic Functionality <./generic_how_to.html>`_: Covering generic functionality all LLMs should have.
2. `Integrations <./integrations.html>`_: Covering integrations with various LLM providers.
3. `Asynchronous <./async_llm.html>`_: Covering asynchronous functionality.
4. `Streaming <./streaming_llm.html>`_: Covering streaming functionality.
.. toctree::
:maxdepth: 1

View File

@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6eaf7e66-f49c-42da-8d11-22ea13bef718",
"metadata": {},
"source": [
"# Streaming with LLMs\n",
"\n",
"LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Verse 1\n",
"I'm sippin' on sparkling water,\n",
"It's so refreshing and light,\n",
"It's the perfect way to quench my thirst,\n",
"On a hot summer night.\n",
"\n",
"Chorus\n",
"Sparkling water, sparkling water,\n",
"It's the best way to stay hydrated,\n",
"It's so refreshing and light,\n",
"It's the perfect way to stay alive.\n",
"\n",
"Verse 2\n",
"I'm sippin' on sparkling water,\n",
"It's so bubbly and bright,\n",
"It's the perfect way to cool me down,\n",
"On a hot summer night.\n",
"\n",
"Chorus\n",
"Sparkling water, sparkling water,\n",
"It's the best way to stay hydrated,\n",
"It's so refreshing and light,\n",
"It's the perfect way to stay alive.\n",
"\n",
"Verse 3\n",
"I'm sippin' on sparkling water,\n",
"It's so crisp and clean,\n",
"It's the perfect way to keep me going,\n",
"On a hot summer day.\n",
"\n",
"Chorus\n",
"Sparkling water, sparkling water,\n",
"It's the best way to stay hydrated,\n",
"It's so refreshing and light,\n",
"It's the perfect way to stay alive."
]
}
],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"\n",
"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"resp = llm(\"Write me a song about sparkling water.\")"
]
},
{
"cell_type": "markdown",
"id": "61fb6de7-c6c8-48d0-a48e-1204c027a23c",
"metadata": {
"tags": []
},
"source": [
"We still have access to the end `LLMResult` if using `generate`. However, `token_usage` is not currently supported for streaming."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a35373f1-9ee6-4753-a343-5aee749b8527",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Q: What did the fish say when it hit the wall?\n",
"A: Dam!"
]
},
{
"data": {
"text/plain": [
"LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.generate([\"Tell me a joke.\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -375,6 +375,22 @@ class AgentExecutor(Chain, BaseModel):
final_output["intermediate_steps"] = intermediate_steps
return final_output
async def _areturn(
self, output: AgentFinish, intermediate_steps: list
) -> Dict[str, Any]:
if self.callback_manager.is_async:
await self.callback_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
else:
self.callback_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
def _take_next_step(
self,
name_to_tool_map: Dict[str, Tool],
@ -428,6 +444,90 @@ class AgentExecutor(Chain, BaseModel):
return AgentFinish({self.agent.return_values[0]: observation}, "")
return output, observation
async def _atake_next_step(
self,
name_to_tool_map: Dict[str, Tool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
) -> Union[AgentFinish, Tuple[AgentAction, str]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
# Call the LLM to see what to do.
output = await self.agent.aplan(intermediate_steps, **inputs)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool]
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": str(tool.func)[:60] + "..."},
output,
verbose=self.verbose,
)
else:
self.callback_manager.on_tool_start(
{"name": str(tool.func)[:60] + "..."},
output,
verbose=self.verbose,
)
try:
# We then call the tool on the tool input to get an observation
observation = (
await tool.coroutine(output.tool_input)
if tool.coroutine
# If the tool is not a coroutine, we run it in the executor
# to avoid blocking the event loop.
else await asyncio.get_event_loop().run_in_executor(
None, tool.func, output.tool_input
)
)
color = color_mapping[output.tool]
return_direct = tool.return_direct
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=self.verbose)
else:
self.callback_manager.on_tool_error(e, verbose=self.verbose)
raise e
else:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": "N/A"}, output, verbose=self.verbose
)
else:
self.callback_manager.on_tool_start(
{"name": "N/A"}, output, verbose=self.verbose
)
observation = f"{output.tool} is not a valid tool, try another one."
color = None
return_direct = False
llm_prefix = "" if return_direct else self.agent.llm_prefix
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation,
color=color,
observation_prefix=self.agent.observation_prefix,
llm_prefix=llm_prefix,
verbose=self.verbose,
)
else:
self.callback_manager.on_tool_end(
observation,
color=color,
observation_prefix=self.agent.observation_prefix,
llm_prefix=llm_prefix,
verbose=self.verbose,
)
if return_direct:
# Set the log to "" because we do not want to log it.
return AgentFinish({self.agent.return_values[0]: observation}, "")
return output, observation
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Make sure that every tool is synchronous (not a coroutine)
@ -486,58 +586,15 @@ class AgentExecutor(Chain, BaseModel):
iterations = 0
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations):
# Call the LLM to see what to do.
output = await self.agent.aplan(intermediate_steps, **inputs)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return self._return(output, intermediate_steps)
next_step_output = await self._atake_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool]
self.callback_manager.on_tool_start(
{"name": str(tool.func)[:60] + "..."},
output,
verbose=self.verbose,
)
try:
# We then call the tool on the tool input to get an observation
observation = (
await tool.coroutine(output.tool_input)
if tool.coroutine
# If the tool is not a coroutine, we run it in the executor
# to avoid blocking the event loop.
else await asyncio.get_event_loop().run_in_executor(
None, tool.func, output.tool_input
)
)
color = color_mapping[output.tool]
return_direct = tool.return_direct
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_tool_error(e, verbose=self.verbose)
raise e
else:
self.callback_manager.on_tool_start(
{"name": "N/A"}, output, verbose=self.verbose
)
observation = f"{output.tool} is not a valid tool, try another one."
color = None
return_direct = False
llm_prefix = "" if return_direct else self.agent.llm_prefix
self.callback_manager.on_tool_end(
observation,
color=color,
observation_prefix=self.agent.observation_prefix,
llm_prefix=llm_prefix,
verbose=self.verbose,
)
intermediate_steps.append((output, observation))
if return_direct:
# Set the log to "" because we do not want to log it.
output = AgentFinish({self.agent.return_values[0]: observation}, "")
return self._return(output, intermediate_steps)
intermediate_steps.append(next_step_output)
iterations += 1
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return self._return(output, intermediate_steps)
return await self._areturn(output, intermediate_steps)

View File

@ -1,5 +1,6 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
import asyncio
import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
@ -32,63 +33,72 @@ class BaseCallbackHandler(ABC):
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
) -> Any:
"""Run when LLM starts running."""
@abstractmethod
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
@abstractmethod
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
) -> Any:
"""Run when LLM errors."""
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
) -> Any:
"""Run when chain starts running."""
@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
) -> Any:
"""Run when chain errors."""
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
) -> Any:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
) -> Any:
"""Run when tool errors."""
@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> None:
def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""
@abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
"""Base callback manager that can be used to handle callbacks from LangChain."""
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
return False
@abstractmethod
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
@ -126,6 +136,15 @@ class CallbackManager(BaseCallbackManager):
if verbose or handler.always_verbose:
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM generates a new token."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_new_token(token, **kwargs)
def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
@ -239,3 +258,287 @@ class CallbackManager(BaseCallbackManager):
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain."""
@property
def is_async(self) -> bool:
"""Return whether the handler is async."""
return True
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_start):
await handler.on_llm_start(serialized, prompts, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_start, serialized, prompts, **kwargs
),
)
async def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_new_token):
await handler.on_llm_new_token(token, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_new_token, token, **kwargs
),
)
async def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_end):
await handler.on_llm_end(response, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_llm_end, response, **kwargs),
)
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_error):
await handler.on_llm_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_llm_error, error, **kwargs),
)
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_chain_start):
await handler.on_chain_start(serialized, inputs, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_chain_start, serialized, inputs, **kwargs
),
)
async def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_chain_end):
await handler.on_chain_end(outputs, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_chain_end, outputs, **kwargs),
)
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_chain_error):
await handler.on_chain_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_chain_error, error, **kwargs),
)
async def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_start):
await handler.on_tool_start(serialized, action, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_tool_start, serialized, action, **kwargs
),
)
async def on_tool_end(
self, output: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_end):
await handler.on_tool_end(output, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_tool_end, output, **kwargs),
)
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_error):
await handler.on_tool_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_tool_error, error, **kwargs),
)
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when text is printed."""
for handler in self.handlers:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_text):
await handler.on_text(text, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(handler.on_text, text, **kwargs)
)
async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when agent finishes."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_agent_finish):
await handler.on_agent_finish(finish, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_finish, finish, **kwargs
),
)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers

View File

@ -21,8 +21,12 @@ class OpenAICallbackHandler(BaseCallbackHandler):
"""Print out the prompts."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
"""Collect token usage."""
if response.llm_output is not None:
if "token_usage" in response.llm_output:
token_usage = response.llm_output["token_usage"]

View File

@ -46,6 +46,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock:
self._callback_manager.on_llm_end(response, **kwargs)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
with self._lock:
self._callback_manager.on_llm_new_token(token, **kwargs)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:

View File

@ -23,6 +23,10 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:

View File

@ -0,0 +1,60 @@
"""Callback Handler streams to stdout on new llm token."""
import sys
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""

View File

@ -18,6 +18,10 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
for prompt in prompts:
st.write(prompt)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
pass

View File

@ -129,6 +129,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
self._start_trace(llm_run)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Handle a new token for an LLM run."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""End a trace for an LLM run."""
if not self._stack or not isinstance(self._stack[-1], LLMRun):

View File

@ -158,6 +158,13 @@ class Chain(BaseModel, ABC):
"""
inputs = self.prep_inputs(inputs)
if self.callback_manager.is_async:
await self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
else:
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
@ -166,8 +173,14 @@ class Chain(BaseModel, ABC):
try:
outputs = await self._acall(inputs)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_chain_error(e, verbose=self.verbose)
else:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
else:
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)

View File

@ -83,3 +83,20 @@ class ChatVectorDBChain(Chain, BaseModel):
new_inputs["chat_history"] = chat_history_str
answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs)
return {self.output_key: answer}
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
question = inputs["question"]
chat_history_str = _get_chat_history(inputs["chat_history"])
if chat_history_str:
new_question = await self.question_generator.arun(
question=question, chat_history=chat_history_str
)
else:
new_question = question
# TODO: This blocks the event loop, but it's not clear how to avoid it.
docs = self.vectorstore.similarity_search(new_question, k=4)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs)
return {self.output_key: answer}

View File

@ -43,6 +43,12 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine documents into a single string."""
@abstractmethod
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents into a single string asynchronously."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
@ -51,6 +57,14 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
extra_return_dict[self.output_key] = output
return extra_return_dict
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = await self.acombine_docs(docs, **other_keys)
extra_return_dict[self.output_key] = output
return extra_return_dict
class AnalyzeDocumentChain(Chain, BaseModel):
"""Chain that splits documents, then analyzes it in pieces."""

View File

@ -140,6 +140,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
)
return self._process_results(results, docs, token_max, **kwargs)
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results.
This reducing can be done recursively if needed (if there are many documents).
"""
results = await self.llm_chain.aapply(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
)
return self._process_results(results, docs, **kwargs)
def _process_results(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
**kwargs: Any,
) -> Tuple[str, dict]:
question_result_key = self.llm_chain.output_key
result_docs = [
Document(page_content=r[question_result_key], metadata=docs[i].metadata)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from pydantic import BaseModel, Extra, root_validator
@ -98,8 +98,27 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
)
typed_results = cast(List[dict], results)
return self._process_results(docs, results)
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
"""
results = await self.llm_chain.aapply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
)
return self._process_results(docs, results)
def _process_results(
self,
docs: List[Document],
results: Sequence[Union[str, List[str], Dict[str, str]]],
) -> Tuple[str, dict]:
typed_results = cast(List[dict], results)
sorted_res = sorted(
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
)

View File

@ -84,6 +84,50 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain."""
inputs = self._construct_initial_inputs(docs, **kwargs)
res = self.initial_llm_chain.predict(**inputs)
refine_steps = [res]
for doc in docs[1:]:
base_inputs = self._construct_refine_inputs(doc, res)
inputs = {**base_inputs, **kwargs}
res = self.refine_llm_chain.predict(**inputs)
refine_steps.append(res)
return self._construct_result(refine_steps, res)
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain."""
inputs = self._construct_initial_inputs(docs, **kwargs)
res = await self.initial_llm_chain.apredict(**inputs)
refine_steps = [res]
for doc in docs[1:]:
base_inputs = self._construct_refine_inputs(doc, res)
inputs = {**base_inputs, **kwargs}
res = await self.refine_llm_chain.apredict(**inputs)
refine_steps.append(res)
return self._construct_result(refine_steps, res)
def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
if self.return_intermediate_steps:
extra_return_dict = {"intermediate_steps": refine_steps}
else:
extra_return_dict = {}
return res, extra_return_dict
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
base_inputs = {
self.document_variable_name: self.document_prompt.format(**document_info),
self.initial_response_name: res,
}
return base_inputs
def _construct_initial_inputs(
self, docs: List[Document], **kwargs: Any
) -> Dict[str, Any]:
base_info = {"page_content": docs[0].page_content}
base_info.update(docs[0].metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
@ -91,28 +135,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
self.document_variable_name: self.document_prompt.format(**document_info)
}
inputs = {**base_inputs, **kwargs}
res = self.initial_llm_chain.predict(**inputs)
refine_steps = [res]
for doc in docs[1:]:
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
document_info = {
k: base_info[k] for k in self.document_prompt.input_variables
}
base_inputs = {
self.document_variable_name: self.document_prompt.format(
**document_info
),
self.initial_response_name: res,
}
inputs = {**base_inputs, **kwargs}
res = self.refine_llm_chain.predict(**inputs)
refine_steps.append(res)
if self.return_intermediate_steps:
extra_return_dict = {"intermediate_steps": refine_steps}
else:
extra_return_dict = {}
return res, extra_return_dict
return inputs
@property
def _chain_type(self) -> str:

View File

@ -84,6 +84,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
# Call predict on the LLM.
return self.llm_chain.predict(**inputs), {}
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return await self.llm_chain.apredict(**inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"

View File

@ -61,7 +61,7 @@ class LLMChain(Chain, BaseModel):
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
prompts, stop = await self.aprep_prompts(input_list)
response = await self.llm.agenerate(prompts, stop=stop)
return response
@ -86,6 +86,32 @@ class LLMChain(Chain, BaseModel):
prompts.append(prompt)
return prompts, stop
async def aprep_prompts(
self, input_list: List[Dict[str, Any]]
) -> Tuple[List[str], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
_colored_text = get_colored_text(prompt, "green")
_text = "Prompt after formatting:\n" + _colored_text
if self.callback_manager.is_async:
await self.callback_manager.on_text(
_text, end="\n", verbose=self.verbose
)
else:
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = self.generate(input_list)
@ -156,6 +182,11 @@ class LLMChain(Chain, BaseModel):
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list)
return self._parse_result(result)
def _parse_result(
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
new_result = []
for res in result:
@ -165,6 +196,13 @@ class LLMChain(Chain, BaseModel):
else:
return result
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = await self.aapply(input_list)
return self._parse_result(result)
@property
def _chain_type(self) -> str:
return "llm_chain"

View File

@ -1,6 +1,7 @@
"""Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@ -31,14 +32,19 @@ def _load_map_rerank_chain(
document_variable_name: str = "context",
rank_key: str = "score",
answer_key: str = "answer",
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> MapRerankDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
llm_chain = LLMChain(
llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
)
return MapRerankDocumentsChain(
llm_chain=llm_chain,
rank_key=rank_key,
answer_key=answer_key,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
@ -48,14 +54,18 @@ def _load_stuff_chain(
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
llm_chain = LLMChain(
llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
)
# TODO: document prompt
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
@ -70,16 +80,28 @@ def _load_map_reduce_chain(
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
map_chain = LLMChain(
llm=llm,
prompt=question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
reduce_chain = LLMChain(
llm=_reduce_llm,
prompt=combine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
# TODO: document prompt
combine_document_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
)
if collapse_prompt is None:
collapse_chain = None
@ -95,8 +117,11 @@ def _load_map_reduce_chain(
llm=_collapse_llm,
prompt=collapse_prompt,
verbose=verbose,
callback_manager=callback_manager,
),
document_variable_name=combine_document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
@ -104,6 +129,7 @@ def _load_map_reduce_chain(
document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
@ -116,17 +142,29 @@ def _load_refine_chain(
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
initial_chain = LLMChain(
llm=llm,
prompt=question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_refine_llm = refine_llm or llm
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
refine_chain = LLMChain(
llm=_refine_llm,
prompt=refine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name=document_variable_name,
initial_response_name=initial_response_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
@ -135,6 +173,7 @@ def load_qa_chain(
llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
@ -145,6 +184,7 @@ def load_qa_chain(
"map_reduce", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.
Returns:
A chain to use for question answering.
@ -160,4 +200,6 @@ def load_qa_chain(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
return loader_mapping[chain_type](
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
)

View File

@ -165,14 +165,25 @@ class BaseLLM(BaseModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self.dict()
@ -184,14 +195,31 @@ class BaseLLM(BaseModel, ABC):
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
{"name": self.__class__.__name__},
missing_prompts,
verbose=self.verbose,
)
try:
new_results = await self._agenerate(missing_prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(
new_results, verbose=self.verbose
)
else:
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts

View File

@ -42,6 +42,27 @@ def update_token_usage(
token_usage[_key] += response["usage"][_key]
def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
"""Update response from the stream response."""
response["choices"][0]["text"] += stream_response["choices"][0]["text"]
response["choices"][0]["finish_reason"] = stream_response["choices"][0][
"finish_reason"
]
response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]
def _streaming_response_template() -> Dict[str, Any]:
return {
"choices": [
{
"text": "",
"finish_reason": None,
"logprobs": None,
}
]
}
class BaseOpenAI(BaseLLM, BaseModel):
"""Wrapper around OpenAI large language models.
@ -88,6 +109,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
"""Adjust the probability of specific tokens being generated."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
class Config:
"""Configuration for this pydantic object."""
@ -129,6 +152,10 @@ class BaseOpenAI(BaseLLM, BaseModel):
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
@ -215,8 +242,24 @@ class BaseOpenAI(BaseLLM, BaseModel):
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
for stream_resp in self.completion_with_retry(
prompt=_prompts, **params
):
self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"], verbose=self.verbose
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
else:
response = self.completion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
@ -232,9 +275,29 @@ class BaseOpenAI(BaseLLM, BaseModel):
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
async for stream_resp in await self.acompletion_with_retry(
prompt=_prompts, **params
):
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"], verbose=self.verbose
)
else:
self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"], verbose=self.verbose
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
else:
response = await self.acompletion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
@ -304,6 +367,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
for token in generator:
yield token
"""
params = self.prep_streaming_params(stop)
generator = self.client.create(prompt=prompt, **params)
return generator
def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""Prepare the params for streaming."""
params = self._invocation_params
if params["best_of"] != 1:
raise ValueError("OpenAI only supports best_of == 1 for streaming")
@ -312,9 +382,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
params["stream"] = True
generator = self.client.create(prompt=prompt, **params)
return generator
return params
@property
def _invocation_params(self) -> Dict[str, Any]:

View File

@ -5,9 +5,11 @@ from typing import Generator
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_openai_call() -> None:
@ -77,9 +79,66 @@ def test_openai_streaming_error() -> None:
llm.stream("I'm Pickle Rick")
def test_openai_streaming_best_of_error() -> None:
"""Test validation for streaming fails if best_of is not 1."""
with pytest.raises(ValueError):
OpenAI(best_of=2, streaming=True)
def test_openai_streaming_n_error() -> None:
"""Test validation for streaming fails if n is not 1."""
with pytest.raises(ValueError):
OpenAI(n=2, streaming=True)
def test_openai_streaming_multiple_prompts_error() -> None:
"""Test validation for streaming fails if multiple prompts are given."""
with pytest.raises(ValueError):
OpenAI(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"])
def test_openai_streaming_call() -> None:
"""Test valid call to openai."""
llm = OpenAI(max_tokens=10, streaming=True)
output = llm("Say foo:")
assert isinstance(output, str)
def test_openai_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = OpenAI(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
llm("Write me a sentence with 100 words.")
assert callback_handler.llm_streams == 10
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
"""Test async generation."""
llm = OpenAI(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
@pytest.mark.asyncio
async def test_openai_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = OpenAI(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
result = await llm.agenerate(["Write me a sentence with 100 words."])
assert callback_handler.llm_streams == 10
assert isinstance(result, LLMResult)

View File

@ -3,12 +3,12 @@ from typing import Any, Dict, List, Union
from pydantic import BaseModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Fake callback handler for testing."""
class BaseFakeCallbackHandler(BaseModel):
"""Base fake callback handler for testing."""
starts: int = 0
ends: int = 0
@ -44,10 +44,15 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
llm_streams: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_ends: int = 0
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
"""Fake callback handler for testing."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@ -55,6 +60,10 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
self.llm_starts += 1
self.starts += 1
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
@ -110,3 +119,74 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Fake async callback handler for testing."""
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.llm_starts += 1
self.starts += 1
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
self.ends += 1
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.errors += 1
async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.chain_starts += 1
self.starts += 1
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.chain_ends += 1
self.ends += 1
async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
self.errors += 1
async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
self.starts += 1
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.tool_ends += 1
self.ends += 1
async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
self.errors += 1
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.text += 1
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1

View File

@ -1,13 +1,24 @@
"""Test CallbackManager."""
from typing import Tuple
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
import pytest
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
@ -20,6 +31,27 @@ def _test_callback_manager(
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)
async def _test_callback_manager_async(
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
await manager.on_llm_start({}, [])
await manager.on_llm_end(LLMResult(generations=[]))
await manager.on_llm_error(Exception())
await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({})
await manager.on_chain_error(Exception())
await manager.on_tool_start({}, AgentAction("", "", ""))
await manager.on_tool_end("")
await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
for handler in handlers:
if handler.always_verbose:
assert handler.starts == 3
@ -128,3 +160,21 @@ def test_shared_callback_manager() -> None:
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)
@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)
@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)