Enable streaming for OpenAI LLM (#986)
* Support a callback `on_llm_new_token` that users can implement when `OpenAI.streaming` is set to `True`
This commit is contained in:
parent f05f025e41
commit caa8e4742e
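As a minimal sketch of what implementing that hook can look like (illustrative only, not part of this commit: the `CollectTokensHandler` name is made up, while the `OpenAI` and `CallbackManager` usage mirrors the notebooks in this diff), a handler only has to override `on_llm_new_token`; the remaining callback methods are abstract in this version of the interface, so they get empty bodies here:

from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler, CallbackManager
from langchain.llms import OpenAI
from langchain.schema import AgentAction, AgentFinish, LLMResult


class CollectTokensHandler(BaseCallbackHandler):
    """Collect streamed tokens in memory; every other hook is a no-op."""

    def __init__(self) -> None:
        self.tokens: List[str] = []

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Called once per generated token while OpenAI.streaming is True.
        self.tokens.append(token)

    # The base class marks the remaining hooks as abstract, so they need bodies.
    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None: ...
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: ...
    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: ...
    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None: ...
    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: ...
    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: ...
    def on_tool_start(self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any) -> None: ...
    def on_tool_end(self, output: str, **kwargs: Any) -> None: ...
    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: ...
    def on_text(self, text: str, **kwargs: Any) -> None: ...
    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: ...


handler = CollectTokensHandler()
llm = OpenAI(streaming=True, callback_manager=CallbackManager([handler]), verbose=True, temperature=0)
llm("Tell me a joke.")
print("".join(handler.tokens))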
@ -14,7 +14,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "70c4e529",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
@ -36,7 +38,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "01c46e92",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
@ -56,7 +60,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "433363a5",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# loaders = [....]\n",
|
||||
@ -75,9 +81,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "a8930cf7",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -106,9 +114,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 5,
|
||||
"id": "7b4110f3",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore)"
|
||||
@ -126,7 +136,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "7fe3e730",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat_history = []\n",
|
||||
@ -136,9 +148,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"id": "bfff9cc8",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -146,7 +160,7 @@
|
||||
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -165,9 +179,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 8,
|
||||
"id": "00b4cf00",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat_history = [(query, result[\"answer\"])]\n",
|
||||
@ -177,9 +193,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 9,
|
||||
"id": "f01828d1",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -187,7 +205,7 @@
|
||||
"' Justice Stephen Breyer'"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -196,10 +214,90 @@
"result['answer']"
]
},
{
"cell_type": "markdown",
"id": "2324cdc6-98bf-4708-b8cd-02a98b1e5b67",
"metadata": {},
"source": [
"## Chat Vector DB with streaming to `stdout`\n",
"\n",
"Output from the chain will be streamed to `stdout` token by token in this example."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2efacec3-2690-4b05-8de3-a32fd2ac3911",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chains.llm import LLMChain\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"\n",
"# Construct a ChatVectorDBChain with a streaming llm for combine docs\n",
"# and a separate, non-streaming llm for question generation\n",
"llm = OpenAI(temperature=0)\n",
"streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"\n",
"question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
"doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n",
"\n",
"qa = ChatVectorDBChain(vectorstore=vectorstore, combine_docs_chain=doc_chain, question_generator=question_generator)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fd6d43f4-7428-44a4-81bc-26fe88a98762",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans."
]
}
],
"source": [
"chat_history = []\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5ab38978-f3e8-4fa7-808c-c79dec48379a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Justice Stephen Breyer"
]
}
],
"source": [
"chat_history = [(query, result[\"answer\"])]\n",
"query = \"Did he mention who she suceeded\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0f869c6",
"id": "a7ea93ff-1899-4171-9c24-85df20ae1a3d",
"metadata": {},
"outputs": [],
"source": []
@@ -221,7 +319,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,
@ -18,7 +18,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "df924055",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI"
|
||||
@ -207,14 +209,6 @@
|
||||
"source": [
|
||||
"llm.get_num_tokens(\"what a joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b004ffdd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -8,6 +8,7 @@ They are split into two categories:
1. `Generic Functionality <./generic_how_to.html>`_: Covering generic functionality all LLMs should have.
2. `Integrations <./integrations.html>`_: Covering integrations with various LLM providers.
3. `Asynchronous <./async_llm.html>`_: Covering asynchronous functionality.
4. `Streaming <./streaming_llm.html>`_: Covering streaming functionality.

.. toctree::
   :maxdepth: 1
docs/modules/llms/streaming_llm.ipynb (new file, 140 lines)
@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6eaf7e66-f49c-42da-8d11-22ea13bef718",
"metadata": {},
"source": [
"# Streaming with LLMs\n",
"\n",
"LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Verse 1\n",
"I'm sippin' on sparkling water,\n",
"It's so refreshing and light,\n",
"It's the perfect way to quench my thirst,\n",
"On a hot summer night.\n",
"\n",
"Chorus\n",
"Sparkling water, sparkling water,\n",
"It's the best way to stay hydrated,\n",
"It's so refreshing and light,\n",
"It's the perfect way to stay alive.\n",
"\n",
"Verse 2\n",
"I'm sippin' on sparkling water,\n",
"It's so bubbly and bright,\n",
"It's the perfect way to cool me down,\n",
"On a hot summer night.\n",
"\n",
"Chorus\n",
"Sparkling water, sparkling water,\n",
"It's the best way to stay hydrated,\n",
"It's so refreshing and light,\n",
"It's the perfect way to stay alive.\n",
"\n",
"Verse 3\n",
"I'm sippin' on sparkling water,\n",
"It's so crisp and clean,\n",
"It's the perfect way to keep me going,\n",
"On a hot summer day.\n",
"\n",
"Chorus\n",
"Sparkling water, sparkling water,\n",
"It's the best way to stay hydrated,\n",
"It's so refreshing and light,\n",
"It's the perfect way to stay alive."
]
}
],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"\n",
"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"resp = llm(\"Write me a song about sparkling water.\")"
]
},
{
"cell_type": "markdown",
"id": "61fb6de7-c6c8-48d0-a48e-1204c027a23c",
"metadata": {
"tags": []
},
"source": [
"We still have access to the end `LLMResult` if using `generate`. However, `token_usage` is not currently supported for streaming."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a35373f1-9ee6-4753-a343-5aee749b8527",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Q: What did the fish say when it hit the wall?\n",
"A: Dam!"
]
},
{
"data": {
"text/plain": [
"LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.generate([\"Tell me a joke.\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@ -375,6 +375,22 @@ class AgentExecutor(Chain, BaseModel):
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
|
||||
async def _areturn(
|
||||
self, output: AgentFinish, intermediate_steps: list
|
||||
) -> Dict[str, Any]:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
|
||||
def _take_next_step(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, Tool],
|
||||
@ -428,6 +444,90 @@ class AgentExecutor(Chain, BaseModel):
|
||||
return AgentFinish({self.agent.return_values[0]: observation}, "")
|
||||
return output, observation
|
||||
|
||||
async def _atake_next_step(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, Tool],
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
) -> Union[AgentFinish, Tuple[AgentAction, str]]:
|
||||
"""Take a single step in the thought-action-observation loop.
|
||||
|
||||
Override this to take control of how the agent makes and acts on choices.
|
||||
"""
|
||||
# Call the LLM to see what to do.
|
||||
output = await self.agent.aplan(intermediate_steps, **inputs)
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
# Otherwise we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[output.tool]
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_start(
|
||||
{"name": str(tool.func)[:60] + "..."},
|
||||
output,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": str(tool.func)[:60] + "..."},
|
||||
output,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = (
|
||||
await tool.coroutine(output.tool_input)
|
||||
if tool.coroutine
|
||||
# If the tool is not a coroutine, we run it in the executor
|
||||
# to avoid blocking the event loop.
|
||||
else await asyncio.get_event_loop().run_in_executor(
|
||||
None, tool.func, output.tool_input
|
||||
)
|
||||
)
|
||||
color = color_mapping[output.tool]
|
||||
return_direct = tool.return_direct
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_tool_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
else:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_start(
|
||||
{"name": "N/A"}, output, verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": "N/A"}, output, verbose=self.verbose
|
||||
)
|
||||
observation = f"{output.tool} is not a valid tool, try another one."
|
||||
color = None
|
||||
return_direct = False
|
||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_end(
|
||||
observation,
|
||||
color=color,
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
llm_prefix=llm_prefix,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_end(
|
||||
observation,
|
||||
color=color,
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
llm_prefix=llm_prefix,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if return_direct:
|
||||
# Set the log to "" because we do not want to log it.
|
||||
return AgentFinish({self.agent.return_values[0]: observation}, "")
|
||||
return output, observation
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Run text through and get agent response."""
|
||||
# Make sure that every tool is synchronous (not a coroutine)
|
||||
@ -486,58 +586,15 @@ class AgentExecutor(Chain, BaseModel):
|
||||
iterations = 0
|
||||
# We now enter the agent loop (until it returns something).
|
||||
while self._should_continue(iterations):
|
||||
# Call the LLM to see what to do.
|
||||
output = await self.agent.aplan(intermediate_steps, **inputs)
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return self._return(output, intermediate_steps)
|
||||
next_step_output = await self._atake_next_step(
|
||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||
)
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
return await self._areturn(next_step_output, intermediate_steps)
|
||||
|
||||
# Otherwise we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[output.tool]
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": str(tool.func)[:60] + "..."},
|
||||
output,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = (
|
||||
await tool.coroutine(output.tool_input)
|
||||
if tool.coroutine
|
||||
# If the tool is not a coroutine, we run it in the executor
|
||||
# to avoid blocking the event loop.
|
||||
else await asyncio.get_event_loop().run_in_executor(
|
||||
None, tool.func, output.tool_input
|
||||
)
|
||||
)
|
||||
color = color_mapping[output.tool]
|
||||
return_direct = tool.return_direct
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": "N/A"}, output, verbose=self.verbose
|
||||
)
|
||||
observation = f"{output.tool} is not a valid tool, try another one."
|
||||
color = None
|
||||
return_direct = False
|
||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||
self.callback_manager.on_tool_end(
|
||||
observation,
|
||||
color=color,
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
llm_prefix=llm_prefix,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
intermediate_steps.append((output, observation))
|
||||
if return_direct:
|
||||
# Set the log to "" because we do not want to log it.
|
||||
output = AgentFinish({self.agent.return_values[0]: observation}, "")
|
||||
return self._return(output, intermediate_steps)
|
||||
intermediate_steps.append(next_step_output)
|
||||
iterations += 1
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return self._return(output, intermediate_steps)
|
||||
return await self._areturn(output, intermediate_steps)
|
||||
|
@@ -1,5 +1,6 @@
"""Base callback handler that can be used to handle callbacks from langchain."""

import asyncio
import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
@ -32,63 +33,72 @@ class BaseCallbackHandler(ABC):
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Whether the callback manager is async."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
@@ -126,6 +136,15 @@ class CallbackManager(BaseCallbackManager):
                if verbose or handler.always_verbose:
                    handler.on_llm_start(serialized, prompts, **kwargs)

    def on_llm_new_token(
        self, token: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM generates a new token."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
                    handler.on_llm_new_token(token, **kwargs)

    def on_llm_end(
        self, response: LLMResult, verbose: bool = False, **kwargs: Any
    ) -> None:
@ -239,3 +258,287 @@ class CallbackManager(BaseCallbackManager):
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Async callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the handler is async."""
|
||||
return True
|
||||
|
||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_llm_start):
|
||||
await handler.on_llm_start(serialized, prompts, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_llm_start, serialized, prompts, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self, token: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_llm_new_token):
|
||||
await handler.on_llm_new_token(token, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_llm_new_token, token, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def on_llm_end(
|
||||
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_llm_end):
|
||||
await handler.on_llm_end(response, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(handler.on_llm_end, response, **kwargs),
|
||||
)
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_llm_error):
|
||||
await handler.on_llm_error(error, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(handler.on_llm_error, error, **kwargs),
|
||||
)
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_chain_start):
|
||||
await handler.on_chain_start(serialized, inputs, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_chain_start, serialized, inputs, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_chain_end):
|
||||
await handler.on_chain_end(outputs, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(handler.on_chain_end, outputs, **kwargs),
|
||||
)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_chain_error):
|
||||
await handler.on_chain_error(error, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(handler.on_chain_error, error, **kwargs),
|
||||
)
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_tool_start):
|
||||
await handler.on_tool_start(serialized, action, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_tool_start, serialized, action, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self, output: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_tool_end):
|
||||
await handler.on_tool_end(output, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(handler.on_tool_end, output, **kwargs),
|
||||
)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_tool_error):
|
||||
await handler.on_tool_error(error, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(handler.on_tool_error, error, **kwargs),
|
||||
)
|
||||
|
||||
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run when text is printed."""
|
||||
for handler in self.handlers:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_text):
|
||||
await handler.on_text(text, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(handler.on_text, text, **kwargs)
|
||||
)
|
||||
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when agent finishes."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_agent_finish):
|
||||
await handler.on_agent_finish(finish, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_agent_finish, finish, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
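The diff above also introduces an `AsyncCallbackManager` in `langchain.callbacks.base`. A rough, illustrative sketch of how it dispatches (not part of the commit itself): coroutine handlers are awaited, while plain synchronous handlers such as the new `StreamingStdOutCallbackHandler` are pushed onto the default executor via `run_in_executor`.

import asyncio

from langchain.callbacks.base import AsyncCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


async def demo() -> None:
    manager = AsyncCallbackManager([StreamingStdOutCallbackHandler()])
    # on_llm_new_token on this handler is not a coroutine, so the manager
    # runs it in the default executor instead of awaiting it directly.
    for token in ["streamed ", "token ", "by ", "token\n"]:
        await manager.on_llm_new_token(token, verbose=True)


asyncio.run(demo())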
@@ -21,8 +21,12 @@ class OpenAICallbackHandler(BaseCallbackHandler):
        """Print out the prompts."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Print out the token."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        """Collect token usage."""
        if response.llm_output is not None:
            if "token_usage" in response.llm_output:
                token_usage = response.llm_output["token_usage"]
@@ -46,6 +46,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
        with self._lock:
            self._callback_manager.on_llm_end(response, **kwargs)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        with self._lock:
            self._callback_manager.on_llm_new_token(token, **kwargs)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
@@ -23,6 +23,10 @@ class StdOutCallbackHandler(BaseCallbackHandler):
        """Do nothing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
langchain/callbacks/streaming_stdout.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""Callback Handler streams to stdout on new llm token."""
import sys
from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult


class StreamingStdOutCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming. Only works with LLMs that support streaming."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
        sys.stdout.flush()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when chain errors."""

    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when tool errors."""

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
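Nothing in this handler is tied to `sys.stdout` beyond the write and flush calls, so a small variant can stream tokens to any writable text object. The `StreamingWriterCallbackHandler` below is an illustrative sketch, not part of this commit:

import io
from typing import Any, TextIO

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


class StreamingWriterCallbackHandler(StreamingStdOutCallbackHandler):
    """Stream new tokens to an arbitrary writable text stream."""

    def __init__(self, stream: TextIO) -> None:
        self.stream = stream

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.stream.write(token)
        self.stream.flush()


# For example, capture the streamed output in memory instead of printing it:
buffer = io.StringIO()
handler = StreamingWriterCallbackHandler(buffer)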
@@ -18,6 +18,10 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
        for prompt in prompts:
            st.write(prompt)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        pass
@@ -129,6 +129,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
        )
        self._start_trace(llm_run)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Handle a new token for an LLM run."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """End a trace for an LLM run."""
        if not self._stack or not isinstance(self._stack[-1], LLMRun):
@ -158,6 +158,13 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
@ -166,8 +173,14 @@ class Chain(BaseModel, ABC):
|
||||
try:
|
||||
outputs = await self._acall(inputs)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||
|
||||
|
@ -83,3 +83,20 @@ class ChatVectorDBChain(Chain, BaseModel):
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs)
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
question = inputs["question"]
|
||||
chat_history_str = _get_chat_history(inputs["chat_history"])
|
||||
if chat_history_str:
|
||||
new_question = await self.question_generator.arun(
|
||||
question=question, chat_history=chat_history_str
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
# TODO: This blocks the event loop, but it's not clear how to avoid it.
|
||||
docs = self.vectorstore.similarity_search(new_question, k=4)
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs)
|
||||
return {self.output_key: answer}
|
||||
|
@ -43,6 +43,12 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string."""
|
||||
|
||||
@abstractmethod
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string asynchronously."""
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
@ -51,6 +57,14 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
output, extra_return_dict = await self.acombine_docs(docs, **other_keys)
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
|
||||
class AnalyzeDocumentChain(Chain, BaseModel):
|
||||
"""Chain that splits documents, then analyzes it in pieces."""
|
||||
|
@ -140,6 +140,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(results, docs, token_max, **kwargs)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map reduce manner.
|
||||
|
||||
Combine by mapping first chain over all documents, then reducing the results.
|
||||
This reducing can be done recursively if needed (if there are many documents).
|
||||
"""
|
||||
results = await self.llm_chain.aapply(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(results, docs, **kwargs)
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
results: List[Dict],
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[str, dict]:
|
||||
question_result_key = self.llm_chain.output_key
|
||||
result_docs = [
|
||||
Document(page_content=r[question_result_key], metadata=docs[i].metadata)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
@ -98,8 +98,27 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
typed_results = cast(List[dict], results)
|
||||
return self._process_results(docs, results)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map rerank manner.
|
||||
|
||||
Combine by mapping first chain over all documents, then reranking the results.
|
||||
"""
|
||||
results = await self.llm_chain.aapply_and_parse(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(docs, results)
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
docs: List[Document],
|
||||
results: Sequence[Union[str, List[str], Dict[str, str]]],
|
||||
) -> Tuple[str, dict]:
|
||||
typed_results = cast(List[dict], results)
|
||||
sorted_res = sorted(
|
||||
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
|
||||
)
|
||||
|
@ -84,6 +84,50 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Combine by mapping first chain over all, then stuffing into final chain."""
|
||||
inputs = self._construct_initial_inputs(docs, **kwargs)
|
||||
res = self.initial_llm_chain.predict(**inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_inputs = self._construct_refine_inputs(doc, res)
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
res = self.refine_llm_chain.predict(**inputs)
|
||||
refine_steps.append(res)
|
||||
return self._construct_result(refine_steps, res)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine by mapping first chain over all, then stuffing into final chain."""
|
||||
inputs = self._construct_initial_inputs(docs, **kwargs)
|
||||
res = await self.initial_llm_chain.apredict(**inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_inputs = self._construct_refine_inputs(doc, res)
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
res = await self.refine_llm_chain.apredict(**inputs)
|
||||
refine_steps.append(res)
|
||||
return self._construct_result(refine_steps, res)
|
||||
|
||||
def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
|
||||
if self.return_intermediate_steps:
|
||||
extra_return_dict = {"intermediate_steps": refine_steps}
|
||||
else:
|
||||
extra_return_dict = {}
|
||||
return res, extra_return_dict
|
||||
|
||||
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
|
||||
base_inputs = {
|
||||
self.document_variable_name: self.document_prompt.format(**document_info),
|
||||
self.initial_response_name: res,
|
||||
}
|
||||
return base_inputs
|
||||
|
||||
def _construct_initial_inputs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
base_info = {"page_content": docs[0].page_content}
|
||||
base_info.update(docs[0].metadata)
|
||||
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
|
||||
@ -91,28 +135,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
self.document_variable_name: self.document_prompt.format(**document_info)
|
||||
}
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
res = self.initial_llm_chain.predict(**inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
document_info = {
|
||||
k: base_info[k] for k in self.document_prompt.input_variables
|
||||
}
|
||||
base_inputs = {
|
||||
self.document_variable_name: self.document_prompt.format(
|
||||
**document_info
|
||||
),
|
||||
self.initial_response_name: res,
|
||||
}
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
res = self.refine_llm_chain.predict(**inputs)
|
||||
refine_steps.append(res)
|
||||
if self.return_intermediate_steps:
|
||||
extra_return_dict = {"intermediate_steps": refine_steps}
|
||||
else:
|
||||
extra_return_dict = {}
|
||||
return res, extra_return_dict
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
|
@ -84,6 +84,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
# Call predict on the LLM.
|
||||
return self.llm_chain.predict(**inputs), {}
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
# Call predict on the LLM.
|
||||
return await self.llm_chain.apredict(**inputs), {}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "stuff_documents_chain"
|
||||
|
@ -61,7 +61,7 @@ class LLMChain(Chain, BaseModel):
|
||||
|
||||
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list)
|
||||
prompts, stop = await self.aprep_prompts(input_list)
|
||||
response = await self.llm.agenerate(prompts, stop=stop)
|
||||
return response
|
||||
|
||||
@ -86,6 +86,32 @@ class LLMChain(Chain, BaseModel):
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
async def aprep_prompts(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt, "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_text(
|
||||
_text, end="\n", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
response = self.generate(input_list)
|
||||
@ -156,6 +182,11 @@ class LLMChain(Chain, BaseModel):
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = self.apply(input_list)
|
||||
return self._parse_result(result)
|
||||
|
||||
def _parse_result(
|
||||
self, result: List[Dict[str, str]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
if self.prompt.output_parser is not None:
|
||||
new_result = []
|
||||
for res in result:
|
||||
@ -165,6 +196,13 @@ class LLMChain(Chain, BaseModel):
|
||||
else:
|
||||
return result
|
||||
|
||||
async def aapply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = await self.aapply(input_list)
|
||||
return self._parse_result(result)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_chain"
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Load question answering chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
@ -31,14 +32,19 @@ def _load_map_rerank_chain(
|
||||
document_variable_name: str = "context",
|
||||
rank_key: str = "score",
|
||||
answer_key: str = "answer",
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapRerankDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
|
||||
)
|
||||
return MapRerankDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
rank_key=rank_key,
|
||||
answer_key=answer_key,
|
||||
document_variable_name=document_variable_name,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -48,14 +54,18 @@ def _load_stuff_chain(
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "context",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
|
||||
)
|
||||
# TODO: document prompt
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)

@ -70,16 +80,28 @@ def _load_map_reduce_chain(
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
map_chain = LLMChain(
llm=llm,
prompt=question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
reduce_chain = LLMChain(
llm=_reduce_llm,
prompt=combine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
# TODO: document prompt
combine_document_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
)
if collapse_prompt is None:
collapse_chain = None
@ -95,8 +117,11 @@ def _load_map_reduce_chain(
llm=_collapse_llm,
prompt=collapse_prompt,
verbose=verbose,
callback_manager=callback_manager,
),
document_variable_name=combine_document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
@ -104,6 +129,7 @@ def _load_map_reduce_chain(
document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)

@ -116,17 +142,29 @@ def _load_refine_chain(
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
initial_chain = LLMChain(
llm=llm,
prompt=question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_refine_llm = refine_llm or llm
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
refine_chain = LLMChain(
llm=_refine_llm,
prompt=refine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name=document_variable_name,
initial_response_name=initial_response_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)

@ -135,6 +173,7 @@ def load_qa_chain(
llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
@ -145,6 +184,7 @@ def load_qa_chain(
"map_reduce", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.

Returns:
A chain to use for question answering.
@ -160,4 +200,6 @@ def load_qa_chain(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
return loader_mapping[chain_type](
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
)
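The practical effect of this change is that a single callback manager handed to load_qa_chain is forwarded to every LLMChain the loader builds, so one set of handlers observes the whole chain. A minimal usage sketch (not part of this commit; the handler list is left empty so the snippet stays self-contained):

from langchain.callbacks.base import CallbackManager
from langchain.chains.question_answering import load_qa_chain
from langchain.llms.openai import OpenAI

# Register any concrete BaseCallbackHandler implementations here.
manager = CallbackManager([])
llm = OpenAI(temperature=0, callback_manager=manager)
qa_chain = load_qa_chain(
    llm,
    chain_type="map_reduce",
    callback_manager=manager,  # now reaches every sub-chain built by the loader
)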
@ -165,14 +165,25 @@ class BaseLLM(BaseModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self.dict()
@ -184,14 +195,31 @@ class BaseLLM(BaseModel, ABC):
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
{"name": self.__class__.__name__},
missing_prompts,
verbose=self.verbose,
)
try:
new_results = await self._agenerate(missing_prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(
new_results, verbose=self.verbose
)
else:
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
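Because BaseLLM cannot know whether the configured manager is the synchronous CallbackManager or an async-capable one, every callback emission in agenerate branches on callback_manager.is_async. The pattern reads as the standalone helper below (a hedged sketch for illustration; run_with_callbacks is not a function in the library):

from typing import Any, Awaitable, Callable

async def run_with_callbacks(
    callback_manager: Any,
    verbose: bool,
    generate: Callable[[], Awaitable[Any]],
) -> Any:
    # Same dispatch rule as the hunk above: async managers are awaited,
    # synchronous ones are called directly.
    try:
        output = await generate()
    except (KeyboardInterrupt, Exception) as e:
        if callback_manager.is_async:
            await callback_manager.on_llm_error(e, verbose=verbose)
        else:
            callback_manager.on_llm_error(e, verbose=verbose)
        raise e
    if callback_manager.is_async:
        await callback_manager.on_llm_end(output, verbose=verbose)
    else:
        callback_manager.on_llm_end(output, verbose=verbose)
    return output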
@ -42,6 +42,27 @@ def update_token_usage(
token_usage[_key] += response["usage"][_key]


def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
"""Update response from the stream response."""
response["choices"][0]["text"] += stream_response["choices"][0]["text"]
response["choices"][0]["finish_reason"] = stream_response["choices"][0][
"finish_reason"
]
response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]


def _streaming_response_template() -> Dict[str, Any]:
return {
"choices": [
{
"text": "",
"finish_reason": None,
"logprobs": None,
}
]
}


class BaseOpenAI(BaseLLM, BaseModel):
"""Wrapper around OpenAI large language models.
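Taken together, _streaming_response_template and _update_response fold the stream back into a single completion-shaped dict, so downstream code can stay unchanged. A small self-contained illustration with made-up chunk data:

# Shape produced by _streaming_response_template()
response = {"choices": [{"text": "", "finish_reason": None, "logprobs": None}]}
# Two fake stream chunks in the layout the OpenAI completion stream uses
chunks = [
    {"choices": [{"text": "Hello", "finish_reason": None, "logprobs": None}]},
    {"choices": [{"text": " world", "finish_reason": "stop", "logprobs": None}]},
]
for chunk in chunks:
    # Same updates _update_response performs for each streamed chunk
    response["choices"][0]["text"] += chunk["choices"][0]["text"]
    response["choices"][0]["finish_reason"] = chunk["choices"][0]["finish_reason"]
    response["choices"][0]["logprobs"] = chunk["choices"][0]["logprobs"]
assert response["choices"][0]["text"] == "Hello world"
assert response["choices"][0]["finish_reason"] == "stop"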
@ -88,6 +109,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
"""Adjust the probability of specific tokens being generated."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""

class Config:
"""Configuration for this pydantic object."""

@ -129,6 +152,10 @@ class BaseOpenAI(BaseLLM, BaseModel):
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values

@property

@ -215,8 +242,24 @@ class BaseOpenAI(BaseLLM, BaseModel):
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
for stream_resp in self.completion_with_retry(
prompt=_prompts, **params
):
self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"], verbose=self.verbose
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
else:
response = self.completion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
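From the caller's side, the streaming branch above means a plain call still returns the fully assembled completion while each chunk is also pushed through on_llm_new_token on the configured callback manager. A short usage sketch mirroring the integration test further down:

from langchain.llms.openai import OpenAI

llm = OpenAI(max_tokens=10, streaming=True)
output = llm("Say foo:")  # chunks also flow through callback_manager.on_llm_new_token
print(output)             # the completion text assembled via _update_response
# Note: token usage is not recorded on this path ("Can't update token usage if streaming").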
@ -232,9 +275,29 @@ class BaseOpenAI(BaseLLM, BaseModel):
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
async for stream_resp in await self.acompletion_with_retry(
prompt=_prompts, **params
):
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"], verbose=self.verbose
)
else:
self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"], verbose=self.verbose
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
else:
response = await self.acompletion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)

@ -304,6 +367,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
for token in generator:
yield token
"""
params = self.prep_streaming_params(stop)
generator = self.client.create(prompt=prompt, **params)

return generator

def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""Prepare the params for streaming."""
params = self._invocation_params
if params["best_of"] != 1:
raise ValueError("OpenAI only supports best_of == 1 for streaming")
@ -312,9 +382,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
params["stream"] = True
generator = self.client.create(prompt=prompt, **params)

return generator
return params

@property
def _invocation_params(self) -> Dict[str, Any]:
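For callers who want the raw chunks rather than callbacks, stream() now delegates parameter handling to prep_streaming_params and returns the OpenAI client's generator directly. A hedged consumption sketch, assuming the chunk layout that _update_response relies on:

from langchain.llms.openai import OpenAI

llm = OpenAI(max_tokens=10)  # best_of must stay at its default of 1 for streaming
for chunk in llm.stream("Tell me a joke."):
    print(chunk["choices"][0]["text"], end="", flush=True)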
@ -5,9 +5,11 @@ from typing import Generator

import pytest

from langchain.callbacks.base import CallbackManager
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_openai_call() -> None:

@ -77,9 +79,66 @@ def test_openai_streaming_error() -> None:
llm.stream("I'm Pickle Rick")


def test_openai_streaming_best_of_error() -> None:
"""Test validation for streaming fails if best_of is not 1."""
with pytest.raises(ValueError):
OpenAI(best_of=2, streaming=True)


def test_openai_streaming_n_error() -> None:
"""Test validation for streaming fails if n is not 1."""
with pytest.raises(ValueError):
OpenAI(n=2, streaming=True)


def test_openai_streaming_multiple_prompts_error() -> None:
"""Test validation for streaming fails if multiple prompts are given."""
with pytest.raises(ValueError):
OpenAI(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"])


def test_openai_streaming_call() -> None:
"""Test valid call to openai."""
llm = OpenAI(max_tokens=10, streaming=True)
output = llm("Say foo:")
assert isinstance(output, str)


def test_openai_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = OpenAI(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
llm("Write me a sentence with 100 words.")
assert callback_handler.llm_streams == 10


@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
"""Test async generation."""
llm = OpenAI(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)


@pytest.mark.asyncio
async def test_openai_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = OpenAI(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
result = await llm.agenerate(["Write me a sentence with 100 words."])
assert callback_handler.llm_streams == 10
assert isinstance(result, LLMResult)
@ -3,12 +3,12 @@ from typing import Any, Dict, List, Union

from pydantic import BaseModel

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult


class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Fake callback handler for testing."""
class BaseFakeCallbackHandler(BaseModel):
"""Base fake callback handler for testing."""

starts: int = 0
ends: int = 0

@ -44,10 +44,15 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
llm_streams: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_ends: int = 0


class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
"""Fake callback handler for testing."""

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:

@ -55,6 +60,10 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
self.llm_starts += 1
self.starts += 1

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1

@ -110,3 +119,74 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1


class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Fake async callback handler for testing."""

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.llm_starts += 1
self.starts += 1

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1

async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
self.ends += 1

async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.errors += 1

async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.chain_starts += 1
self.starts += 1

async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.chain_ends += 1
self.ends += 1

async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
self.errors += 1

async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
self.starts += 1

async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.tool_ends += 1
self.ends += 1

async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
self.errors += 1

async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.text += 1

async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
@ -1,13 +1,24 @@
"""Test CallbackManager."""
from typing import Tuple

from langchain.callbacks.base import BaseCallbackManager, CallbackManager
import pytest

from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)


def _test_callback_manager(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])

@ -20,6 +31,27 @@ def _test_callback_manager(
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)


async def _test_callback_manager_async(
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
await manager.on_llm_start({}, [])
await manager.on_llm_end(LLMResult(generations=[]))
await manager.on_llm_error(Exception())
await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({})
await manager.on_chain_error(Exception())
await manager.on_tool_start({}, AgentAction("", "", ""))
await manager.on_tool_end("")
await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)


def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
for handler in handlers:
if handler.always_verbose:
assert handler.starts == 3

@ -128,3 +160,21 @@ def test_shared_callback_manager() -> None:
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)


@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)


@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)