Mirror of https://github.com/hwchase17/langchain (synced 2024-11-18 09:25:54 +00:00)
core[minor]: RFC Add astream_events to Runnables (#16172)
This PR adds an `astream_events` method to Runnables to make it easier to
stream data from arbitrary chains.
* Streaming only works properly in async right now
* One should use `astream()` when mixing in imperative code, as might be
done with tool implementations
* `astream_log` has been modified with minimal additive changes, so no
breaking changes are expected
* Underlying callback code / tracing code should be refactored at some
point to handle things more consistently (OK for now)

- ~~[ ] verify event for on_retry~~ does not work until we implement
streaming for retry
- ~~[ ] Any renaming? Should we rename "event" to "hook"?~~
- [ ] Any other feedback from the community?
- [x] throw NotImplementedError for `RunnableEach` for now
## Example
See this [Example
Notebook](dbbc7fa0d6/docs/docs/modules/agents/how_to/streaming_events.ipynb)
for an example of streaming in the context of an Agent.
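
For a sense of the API surface, here is a minimal, self-contained sketch (mirroring the docstring example added in this PR, using only `langchain_core`) of consuming `astream_events`:

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def reverse(s: str) -> str:
    return s[::-1]


async def main() -> None:
    chain = RunnableLambda(func=reverse)
    # Each event is a dict carrying the event name, the runnable's name,
    # a run_id, tags, metadata, and an event-specific `data` payload.
    async for event in chain.astream_events("hello"):
        print(event["event"], event["data"])


asyncio.run(main())
```

This should print an `on_chain_start`, an `on_chain_stream`, and an `on_chain_end` event for the lambda.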
## Event Hooks Reference
Here is a reference table that shows some events that might be emitted by
the various Runnable objects.
Definitions for some of the Runnables are included after the table.

| event | name | chunk | input | output |
|----------------------|------------------|---------------------------------|-----------------------------------------------|-------------------------------------------------|
| on_chat_model_start | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | |
| on_chat_model_stream | [model name] | AIMessageChunk(content="hello") | | |
| on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | {"generations": [...], "llm_output": None, ...} |
| on_llm_start | [model name] | | {'input': 'hello'} | |
| on_llm_stream | [model name] | 'Hello' | | |
| on_llm_end | [model name] | | | 'Hello human!' |
| on_chain_start | format_docs | | | |
| on_chain_stream | format_docs | "hello world!, goodbye world!" | | |
| on_chain_end | format_docs | | [Document(...)] | "hello world!, goodbye world!" |
| on_tool_start | some_tool | | {"x": 1, "y": "2"} | |
| on_tool_stream | some_tool | {"x": 1, "y": "2"} | | |
| on_tool_end | some_tool | | | {"x": 1, "y": "2"} |
| on_retriever_start | [retriever name] | | {"query": "hello"} | |
| on_retriever_chunk | [retriever name] | {documents: [...]} | | |
| on_retriever_end | [retriever name] | | {"query": "hello"} | {documents: [...]} |
| on_prompt_start | [template_name] | | {"question": "hello"} | |
| on_prompt_end | [template_name] | | {"question": "hello"} | ChatPromptValue(messages: [SystemMessage, ...]) |
Here are declarations associated with the events shown above:
`format_docs`:
```python
def format_docs(docs: List[Document]) -> str:
    '''Format the docs.'''
    return ", ".join([doc.page_content for doc in docs])

format_docs = RunnableLambda(format_docs)
```
`some_tool`:
```python
@tool
def some_tool(x: int, y: str) -> dict:
    '''Some_tool.'''
    return {"x": x, "y": y}
```
`prompt`:
```python
template = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
).with_config({"run_name": "my_template", "tags": ["my_template"]})
```
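
Running `format_docs` through `astream_events` should surface the `on_chain_*` rows from the table above. A minimal sketch, assuming in-memory `Document` objects and the declarations exactly as written:

```python
import asyncio
from typing import List

from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda


def format_docs(docs: List[Document]) -> str:
    '''Format the docs.'''
    return ", ".join([doc.page_content for doc in docs])


format_docs = RunnableLambda(format_docs)


async def main() -> None:
    docs = [
        Document(page_content="hello world!"),
        Document(page_content="goodbye world!"),
    ]
    # Expect on_chain_start, then on_chain_stream with
    # "hello world!, goodbye world!", then on_chain_end.
    async for event in format_docs.astream_events(docs):
        print(event["event"], event["data"])


asyncio.run(main())
```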
This commit is contained in:
parent: f175bf7d7b
commit: 177af65dc4

docs/docs/modules/agents/how_to/streaming_events.ipynb (new file, +350 lines)

# Event Streaming

**NEW** This is a new API that only works with recent versions of langchain-core!

In this notebook, we'll see how to use `astream_events` to stream **token by token** from LLM calls used within the tools invoked by the agent.

We will **only** stream tokens from LLMs used within tools and from no other LLMs (just to show that we can)!

Feel free to adapt this example to the needs of your application.

Our agent will use the OpenAI tools API for tool invocation, and we'll provide the agent with two tools:

1. `where_cat_is_hiding`: A tool that uses an LLM to tell us where the cat is hiding
2. `tell_me_a_joke_about`: A tool that can use an LLM to tell a joke about the given topic

## ⚠️ Beta API ⚠️

Event Streaming is a **beta** API, and may change a bit based on feedback.

Keep in mind the following constraints (repeated in the tools section):

* streaming only works properly if using `async`
* propagate callbacks if defining custom functions / runnables
* if creating a tool that uses an LLM, make sure to use `.astream()` on the LLM rather than `.ainvoke` to ask the LLM to stream tokens

## Event Hooks Reference

⚠️ When streaming, the inputs for a runnable will not be available until the input stream has been entirely consumed. This means that the inputs will be available for the corresponding `end` hook rather than the `start` event.

The notebook then repeats the Event Hooks Reference verbatim: the same table and the same `format_docs`, `some_tool`, and `prompt` declarations shown above in the PR description.

```python
from langchain import hub
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.tools import tool
from langchain_core.callbacks import Callbacks
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
```

## Create the model

**Attention** For older versions of langchain, we must set `streaming=True`

```python
model = ChatOpenAI(temperature=0, streaming=True)
```

## Tools

We define two tools that rely on a chat model to generate output!

Please note a few different things:

1. The tools are **async**
1. The model is invoked using **.astream()** to force the output to stream
1. For older langchain versions you should set `streaming=True` on the model!
1. We attach tags to the model so that we can filter on said tags in our callback handler
1. The tools accept callbacks and propagate them to the model as a runtime argument

```python
@tool
async def where_cat_is_hiding(callbacks: Callbacks) -> str:  # <--- Accept callbacks
    """Where is the cat hiding right now?"""
    chunks = [
        chunk
        async for chunk in model.astream(
            "Give one up to three word answer about where the cat might be hiding in the house right now.",
            {
                "tags": ["tool_llm"],
                "callbacks": callbacks,
            },  # <--- Propagate callbacks and assign a tag to this model
        )
    ]
    return "".join(chunk.content for chunk in chunks)


@tool
async def tell_me_a_joke_about(
    topic: str, callbacks: Callbacks
) -> str:  # <--- Accept callbacks
    """Tell a joke about a given topic."""
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are Cat Agent 007. You are funny and know many jokes."),
            ("human", "Tell me a long joke about {topic}"),
        ]
    )
    chain = template | model.with_config({"tags": ["tool_llm"]})
    chunks = [
        chunk
        async for chunk in chain.astream({"topic": topic}, {"callbacks": callbacks})
    ]
    return "".join(chunk.content for chunk in chunks)
```

## Initialize the Agent

```python
# Get the prompt to use - you can modify this!
prompt = hub.pull("hwchase17/openai-tools-agent")
print(prompt)
print(prompt.messages)
```

Output:

```
input_variables=['agent_scratchpad', 'input'] input_types={'chat_history': typing.List[typing.Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]], 'agent_scratchpad': typing.List[typing.Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]} messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a helpful assistant')), MessagesPlaceholder(variable_name='chat_history', optional=True), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')), MessagesPlaceholder(variable_name='agent_scratchpad')]
[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a helpful assistant')), MessagesPlaceholder(variable_name='chat_history', optional=True), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')), MessagesPlaceholder(variable_name='agent_scratchpad')]
```

```python
tools = [tell_me_a_joke_about, where_cat_is_hiding]
agent = create_openai_tools_agent(model.with_config({"tags": ["agent"]}), tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
```

## Stream the output

The streamed output is shown with a `|` as the delimiter between tokens.

```python
async for event in executor.astream_events(
    {"input": "where is the cat hiding? Tell me a joke about that location?"},
    include_tags=["tool_llm"],
    include_types=["tool"],
):
    hook = event["event"]
    if hook == "on_chat_model_stream":
        print(event["data"]["chunk"].content, end="|")
    elif hook in {"on_chat_model_start", "on_chat_model_end"}:
        print()
        print()
    elif hook == "on_tool_start":
        print("--")
        print(
            f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}"
        )
    elif hook == "on_tool_end":
        print(f"Ended tool: {event['name']}")
    else:
        pass
```

stderr:

```
/home/eugene/src/langchain/libs/core/langchain_core/_api/beta_decorator.py:86: LangChainBetaWarning: This API is in beta and may change in the future.
  warn_beta(
```

Output:

```
--
Starting tool: where_cat_is_hiding with inputs: {}


|Under| the| bed|.||

Ended tool: where_cat_is_hiding
--
Starting tool: tell_me_a_joke_about with inputs: {'topic': 'under the bed'}


|Sure|,| here|'s| a| long| joke| about| what|'s| hiding| under| the| bed|:

|Once| upon| a| time|,| there| was| a| mis|chie|vous| little| boy| named| Tim|my|.| Tim|my| had| always| been| afraid| of| what| might| be| lurking| under| his| bed| at| night|.| Every| evening|,| he| would| ti|pt|oe| into| his| room|,| turn| off| the| lights|,| and| then| make| a| daring| leap| onto| his| bed|,| ensuring| that| nothing| could| grab| his| ankles|.

|One| night|,| Tim|my|'s| parents| decided| to| play| a| prank| on| him|.| They| hid| a| remote|-controlled| toy| monster| under| his| bed|,| complete| with| glowing| eyes| and| a| grow|ling| sound| effect|.| As| Tim|my| settled| into| bed|,| his| parents| quietly| sn|uck| into| his| room|,| ready| to| give| him| the| scare| of| a| lifetime|.

|Just| as| Tim|my| was| about| to| drift| off| to| sleep|,| he| heard| a| faint| grow|l| coming| from| under| his| bed|.| His| eyes| widened| with| fear|,| and| his| heart| started| racing|.| He| must|ered| up| the| courage| to| peek| under| the| bed|,| and| to| his| surprise|,| he| saw| a| pair| of| glowing| eyes| staring| back| at| him|.

|Terr|ified|,| Tim|my| jumped| out| of| bed| and| ran| to| his| parents|,| screaming|,| "|There|'s| a| monster| under| my| bed|!| Help|!"

|His| parents|,| trying| to| st|ifle| their| laughter|,| rushed| into| his| room|.| They| pretended| to| be| just| as| scared| as| Tim|my|,| and| together|,| they| brav|ely| approached| the| bed|.| Tim|my|'s| dad| grabbed| a| bro|om|stick|,| ready| to| defend| his| family| against| the| imaginary| monster|.

|As| they| got| closer|,| the| "|monster|"| under| the| bed| started| to| move|.| Tim|my|'s| mom|,| unable| to| contain| her| laughter| any| longer|,| pressed| a| button| on| the| remote| control|,| causing| the| toy| monster| to| sc|urry| out| from| under| the| bed|.| Tim|my|'s| fear| quickly| turned| into| confusion|,| and| then| into| laughter| as| he| realized| it| was| all| just| a| prank|.

|From| that| day| forward|,| Tim|my| learned| that| sometimes| the| things| we| fear| the| most| are| just| fig|ments| of| our| imagination|.| And| as| for| what|'s| hiding| under| his| bed|?| Well|,| it|'s| just| dust| b|unn|ies| and| the| occasional| missing| sock|.| Nothing| to| be| afraid| of|!

|Remember|,| laughter| is| the| best| monster| repell|ent|!||

Ended tool: tell_me_a_joke_about
```
```diff
@@ -219,6 +219,7 @@ class CallbackManagerMixin:
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        inputs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when tool starts running."""

@@ -409,6 +410,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        inputs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when tool starts running."""

@@ -1282,15 +1282,22 @@ class CallbackManager(BaseCallbackManager):
         input_str: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
+        inputs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> CallbackManagerForToolRun:
         """Run when tool starts running.
 
         Args:
-            serialized (Dict[str, Any]): The serialized tool.
-            input_str (str): The input to the tool.
-            run_id (UUID, optional): The ID of the run. Defaults to None.
-            parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
+            serialized: Serialized representation of the tool.
+            input_str: The input to the tool as a string.
+                Non-string inputs are cast to strings.
+            run_id: ID for the run. Defaults to None.
+            parent_run_id: The ID of the parent run. Defaults to None.
+            inputs: The original input to the tool if provided.
+                Recommended for usage instead of input_str when the original
+                input is needed.
+                If provided, the inputs are expected to be formatted as a dict.
+                The keys will correspond to the named-arguments in the tool.
 
         Returns:
             CallbackManagerForToolRun: The callback manager for the tool run.

@@ -1308,6 +1315,7 @@ class CallbackManager(BaseCallbackManager):
             parent_run_id=self.parent_run_id,
             tags=self.tags,
             metadata=self.metadata,
+            inputs=inputs,
             **kwargs,
         )
```
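
For handler authors, the practical effect of the new parameter is illustrated by this sketch (a hypothetical handler, not part of the PR) that prefers the structured `inputs` over the stringified `input_str`:

```python
from typing import Any, Dict, Optional
from uuid import UUID

from langchain_core.callbacks.base import BaseCallbackHandler


class ToolInputLogger(BaseCallbackHandler):
    """Hypothetical handler that logs tool inputs."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        inputs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        # `inputs` carries the original dict keyed by the tool's named
        # arguments when available; fall back to the stringified form.
        print(f"tool run {run_id} started with:", inputs if inputs is not None else input_str)
```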
```diff
@@ -7,7 +7,6 @@ import threading
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
 from contextvars import copy_context
-from copy import deepcopy
 from functools import wraps
 from itertools import groupby, tee
 from operator import itemgetter

@@ -36,7 +35,8 @@ from typing import (
 
 from typing_extensions import Literal, get_args
 
-from langchain_core.load.dump import dumpd, dumps
+from langchain_core._api import beta_decorator
+from langchain_core.load.dump import dumpd
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
 from langchain_core.runnables.config import (

@@ -54,6 +54,7 @@ from langchain_core.runnables.config import (
     var_child_runnable_config,
 )
 from langchain_core.runnables.graph import Graph
+from langchain_core.runnables.schema import EventData, StreamEvent
 from langchain_core.runnables.utils import (
     AddableDict,
     AnyConfigurableField,

@@ -83,7 +84,11 @@ if TYPE_CHECKING:
     from langchain_core.runnables.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
-    from langchain_core.tracers.log_stream import RunLog, RunLogPatch
+    from langchain_core.tracers.log_stream import (
+        LogEntry,
+        RunLog,
+        RunLogPatch,
+    )
     from langchain_core.tracers.root_listeners import Listener

@@ -600,7 +605,7 @@ class Runnable(Generic[Input, Output], ABC):
         exclude_names: Optional[Sequence[str]] = None,
         exclude_types: Optional[Sequence[str]] = None,
         exclude_tags: Optional[Sequence[str]] = None,
-        **kwargs: Optional[Any],
+        **kwargs: Any,
     ) -> AsyncIterator[RunLogPatch]:
         ...

@@ -618,7 +623,7 @@ class Runnable(Generic[Input, Output], ABC):
         exclude_names: Optional[Sequence[str]] = None,
         exclude_types: Optional[Sequence[str]] = None,
         exclude_tags: Optional[Sequence[str]] = None,
-        **kwargs: Optional[Any],
+        **kwargs: Any,
     ) -> AsyncIterator[RunLog]:
         ...

@@ -635,7 +640,7 @@ class Runnable(Generic[Input, Output], ABC):
         exclude_names: Optional[Sequence[str]] = None,
         exclude_types: Optional[Sequence[str]] = None,
         exclude_tags: Optional[Sequence[str]] = None,
-        **kwargs: Optional[Any],
+        **kwargs: Any,
     ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
         """
         Stream all output from a runnable, as reported to the callback system.

@@ -659,16 +664,11 @@ class Runnable(Generic[Input, Output], ABC):
             exclude_types: Exclude logs with these types.
             exclude_tags: Exclude logs with these tags.
         """
-        import jsonpatch  # type: ignore[import]
-
-        from langchain_core.callbacks.base import BaseCallbackManager
         from langchain_core.tracers.log_stream import (
             LogStreamCallbackHandler,
-            RunLog,
-            RunLogPatch,
+            _astream_log_implementation,
         )
 
-        # Create a stream handler that will emit Log objects
         stream = LogStreamCallbackHandler(
             auto_close=False,
             include_names=include_names,

@@ -677,82 +677,336 @@ class Runnable(Generic[Input, Output], ABC):
             exclude_names=exclude_names,
             exclude_types=exclude_types,
             exclude_tags=exclude_tags,
+            _schema_format="original",
         )
 
-        # Assign the stream handler to the config
-        config = ensure_config(config)
-        callbacks = config.get("callbacks")
-        if callbacks is None:
-            config["callbacks"] = [stream]
-        elif isinstance(callbacks, list):
-            config["callbacks"] = callbacks + [stream]
-        elif isinstance(callbacks, BaseCallbackManager):
-            callbacks = callbacks.copy()
-            callbacks.add_handler(stream, inherit=True)
-            config["callbacks"] = callbacks
-        else:
-            raise ValueError(
-                f"Unexpected type for callbacks: {callbacks}."
-                "Expected None, list or AsyncCallbackManager."
-            )
-
-        # Call the runnable in streaming mode,
-        # add each chunk to the output stream
-        async def consume_astream() -> None:
-            try:
-                prev_final_output: Optional[Output] = None
-                final_output: Optional[Output] = None
-
-                async for chunk in self.astream(input, config, **kwargs):
-                    prev_final_output = final_output
-                    if final_output is None:
-                        final_output = chunk
-                    else:
-                        try:
-                            final_output = final_output + chunk  # type: ignore
-                        except TypeError:
-                            final_output = chunk
-                    patches: List[Dict[str, Any]] = []
-                    if with_streamed_output_list:
-                        patches.append(
-                            {
-                                "op": "add",
-                                "path": "/streamed_output/-",
-                                # chunk cannot be shared between
-                                # streamed_output and final_output
-                                # otherwise jsonpatch.apply will
-                                # modify both
-                                "value": deepcopy(chunk),
-                            }
-                        )
-                    for op in jsonpatch.JsonPatch.from_diff(
-                        prev_final_output, final_output, dumps=dumps
-                    ):
-                        patches.append({**op, "path": f"/final_output{op['path']}"})
-                    await stream.send_stream.send(RunLogPatch(*patches))
-            finally:
-                await stream.send_stream.aclose()
-
-        # Start the runnable in a task, so we can start consuming output
-        task = asyncio.create_task(consume_astream())
-
-        try:
-            # Yield each chunk from the output stream
-            if diff:
-                async for log in stream:
-                    yield log
-            else:
-                state = RunLog(state=None)  # type: ignore[arg-type]
-                async for log in stream:
-                    state = state + log
-                    yield state
-        finally:
-            # Wait for the runnable to finish, if not cancelled (eg. by break)
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass
+        # Mypy isn't resolving the overloads here
+        # Likely an issue b/c `self` is being passed through
+        # and it can't map it to Runnable[Input, Output]?
+        async for item in _astream_log_implementation(  # type: ignore
+            self,
+            input,
+            config,
+            diff=diff,
+            stream=stream,
+            with_streamed_output_list=with_streamed_output_list,
+        ):
+            yield item
+
+    @beta_decorator.beta(message="This API is in beta and may change in the future.")
+    async def astream_events(
+        self,
+        input: Any,
+        config: Optional[RunnableConfig] = None,
+        *,
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[StreamEvent]:
+        """Generate a stream of events.
+
+        Use to create an iterator over StreamEvents that provide real-time information
+        about the progress of the runnable, including StreamEvents from intermediate
+        results.
+
+        A StreamEvent is a dictionary with the following schema:
+
+        * ``event``: str - Event names are of the
+            format: on_[runnable_type]_(start|stream|end).
+        * ``name``: str - The name of the runnable that generated the event.
+        * ``run_id``: str - randomly generated ID associated with the given execution of
+            the runnable that emitted the event.
+            A child runnable that gets invoked as part of the execution of a
+            parent runnable is assigned its own unique ID.
+        * ``tags``: Optional[List[str]] - The tags of the runnable that generated
+            the event.
+        * ``metadata``: Optional[Dict[str, Any]] - The metadata of the runnable
+            that generated the event.
+        * ``data``: Dict[str, Any]
+
+        Below is a table that illustrates some events that might be emitted by various
+        chains. Metadata fields have been omitted from the table for brevity.
+        Chain definitions have been included after the table.
+        (The table and the `format_docs` / `some_tool` / `prompt` declarations are
+        identical to the Event Hooks Reference in the PR description above and are
+        elided here.)
+
+        Example:
+
+        .. code-block:: python
+
+            from langchain_core.runnables import RunnableLambda
+
+            async def reverse(s: str) -> str:
+                return s[::-1]
+
+            chain = RunnableLambda(func=reverse)
+
+            events = [event async for event in chain.astream_events("hello")]
+
+            # will produce the following events (run_id has been omitted for brevity):
+            [
+                {
+                    "data": {"input": "hello"},
+                    "event": "on_chain_start",
+                    "metadata": {},
+                    "name": "reverse",
+                    "tags": [],
+                },
+                {
+                    "data": {"chunk": "olleh"},
+                    "event": "on_chain_stream",
+                    "metadata": {},
+                    "name": "reverse",
+                    "tags": [],
+                },
+                {
+                    "data": {"output": "olleh"},
+                    "event": "on_chain_end",
+                    "metadata": {},
+                    "name": "reverse",
+                    "tags": [],
+                },
+            ]
+
+        Args:
+            input: The input to the runnable.
+            config: The config to use for the runnable.
+            include_names: Only include events from runnables with matching names.
+            include_types: Only include events from runnables with matching types.
+            include_tags: Only include events from runnables with matching tags.
+            exclude_names: Exclude events from runnables with matching names.
+            exclude_types: Exclude events from runnables with matching types.
+            exclude_tags: Exclude events from runnables with matching tags.
+            kwargs: Additional keyword arguments to pass to the runnable.
+                These will be passed to astream_log as this implementation
+                of astream_events is built on top of astream_log.
+
+        Returns:
+            An async stream of StreamEvents.
+        """  # noqa: E501
+        from langchain_core.runnables.utils import (
+            _RootEventFilter,
+        )
+        from langchain_core.tracers.log_stream import (
+            LogStreamCallbackHandler,
+            RunLog,
+            _astream_log_implementation,
+        )
+
+        stream = LogStreamCallbackHandler(
+            auto_close=False,
+            include_names=include_names,
+            include_types=include_types,
+            include_tags=include_tags,
+            exclude_names=exclude_names,
+            exclude_types=exclude_types,
+            exclude_tags=exclude_tags,
+            _schema_format="streaming_events",
+        )
+
+        run_log = RunLog(state=None)  # type: ignore[arg-type]
+        encountered_start_event = False
+
+        _root_event_filter = _RootEventFilter(
+            include_names=include_names,
+            include_types=include_types,
+            include_tags=include_tags,
+            exclude_names=exclude_names,
+            exclude_types=exclude_types,
+            exclude_tags=exclude_tags,
+        )
+
+        config = ensure_config(config)
+        root_tags = config.get("tags", [])
+        root_metadata = config.get("metadata", {})
+        root_name = config.get("run_name", self.get_name())
+
+        # Ignoring mypy complaint about too many different union combinations
+        # This arises because many of the argument types are unions
+        async for log in _astream_log_implementation(  # type: ignore[misc]
+            self,
+            input,
+            config=config,
+            stream=stream,
+            diff=True,
+            with_streamed_output_list=True,
+            **kwargs,
+        ):
+            run_log = run_log + log
+
+            if not encountered_start_event:
+                # Yield the start event for the root runnable.
+                encountered_start_event = True
+                state = run_log.state.copy()
+
+                event = StreamEvent(
+                    event=f"on_{state['type']}_start",
+                    run_id=state["id"],
+                    name=root_name,
+                    tags=root_tags,
+                    metadata=root_metadata,
+                    data={
+                        "input": input,
+                    },
+                )
+
+                if _root_event_filter.include_event(event, state["type"]):
+                    yield event
+
+            paths = {
+                op["path"].split("/")[2]
+                for op in log.ops
+                if op["path"].startswith("/logs/")
+            }
+            # Elements in a set should be iterated in the same order
+            # as they were inserted in modern python versions.
+            for path in paths:
+                data: EventData = {}
+                log_entry: LogEntry = run_log.state["logs"][path]
+                if log_entry["end_time"] is None:
+                    if log_entry["streamed_output"]:
+                        event_type = "stream"
+                    else:
+                        event_type = "start"
+                else:
+                    event_type = "end"
+
+                if event_type == "start":
+                    # Include the inputs with the start event if they are available.
+                    # Usually they will NOT be available for components that operate
+                    # on streams, since those components stream the input and
+                    # don't know its final value until the end of the stream.
+                    inputs = log_entry["inputs"]
+                    if inputs is not None:
+                        data["input"] = inputs
+                    pass
+
+                if event_type == "end":
+                    inputs = log_entry["inputs"]
+                    if inputs is not None:
+                        data["input"] = inputs
+
+                    # None is a VALID output for an end event
+                    data["output"] = log_entry["final_output"]
+
+                if event_type == "stream":
+                    num_chunks = len(log_entry["streamed_output"])
+                    if num_chunks != 1:
+                        raise AssertionError(
+                            f"Expected exactly one chunk of streamed output, "
+                            f"got {num_chunks} instead. This is impossible. "
+                            f"Encountered in: {log_entry['name']}"
+                        )
+
+                    data = {"chunk": log_entry["streamed_output"][0]}
+                    # Clean up the stream, we don't need it anymore.
+                    # And this avoids duplicates as well!
+                    log_entry["streamed_output"] = []
+
+                yield StreamEvent(
+                    event=f"on_{log_entry['type']}_{event_type}",
+                    name=log_entry["name"],
+                    run_id=log_entry["id"],
+                    tags=log_entry["tags"],
+                    metadata=log_entry["metadata"],
+                    data=data,
+                )
+
+            # Finally, we take care of the streaming output from the root chain
+            # if there is any.
+            state = run_log.state
+            if state["streamed_output"]:
+                num_chunks = len(state["streamed_output"])
+                if num_chunks != 1:
+                    raise AssertionError(
+                        f"Expected exactly one chunk of streamed output, "
+                        f"got {num_chunks} instead. This is impossible. "
+                        f"Encountered in: {state['name']}"
+                    )
+
+                data = {"chunk": state["streamed_output"][0]}
+                # Clean up the stream, we don't need it anymore.
+                state["streamed_output"] = []
+
+                event = StreamEvent(
+                    event=f"on_{state['type']}_stream",
+                    run_id=state["id"],
+                    tags=root_tags,
+                    metadata=root_metadata,
+                    name=root_name,
+                    data=data,
+                )
+                if _root_event_filter.include_event(event, state["type"]):
+                    yield event
+
+        state = run_log.state
+
+        # Finally yield the end event for the root runnable.
+        event = StreamEvent(
+            event=f"on_{state['type']}_end",
+            name=root_name,
+            run_id=state["id"],
+            tags=root_tags,
+            metadata=root_metadata,
+            data={
+                "output": state["final_output"],
+            },
+        )
+        if _root_event_filter.include_event(event, state["type"]):
+            yield event
 
     def transform(
         self,
         input: Iterator[Input],
```
```diff
@@ -3396,6 +3650,18 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
     ) -> List[Output]:
         return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
 
+    async def astream_events(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[StreamEvent]:
+        for _ in range(1):
+            raise NotImplementedError(
+                "RunnableEach does not support astream_events yet."
+            )
+        yield
+
 
 class RunnableEach(RunnableEachBase[Input, Output]):
     """
```
```diff
@@ -3686,6 +3952,17 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
         ):
             yield item
 
+    async def astream_events(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[StreamEvent]:
+        async for item in self.bound.astream_events(
+            input, self._merge_configs(config), **{**self.kwargs, **kwargs}
+        ):
+            yield item
+
     def transform(
         self,
         input: Iterator[Input],
```
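
Because `RunnableBindingBase` simply delegates to the bound runnable, `astream_events` works through `.with_config()` bindings. A minimal sketch of the assumed behavior, using only `langchain_core`:

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def reverse(s: str) -> str:
    return s[::-1]


# The binding merges the run name and tags into the config at event time.
chain = RunnableLambda(func=reverse).with_config(
    {"run_name": "reverse", "tags": ["demo"]}
)


async def main() -> None:
    async for event in chain.astream_events("hello"):
        # Expect on_chain_* events with name "reverse" and tags ["demo"].
        print(event["event"], event.get("tags"))


asyncio.run(main())
```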
libs/core/langchain_core/runnables/schema.py (new file, +133 lines)

```python
"""Module contains typedefs that are used with runnables."""
from __future__ import annotations

from typing import Any, Dict, List

from typing_extensions import NotRequired, TypedDict


class EventData(TypedDict, total=False):
    """Data associated with a streaming event."""

    input: Any
    """The input passed to the runnable that generated the event.

    Inputs will sometimes be available at the *START* of the runnable, and
    sometimes at the *END* of the runnable.

    If a runnable is able to stream its inputs, then its input by definition
    won't be known until the *END* of the runnable when it has finished streaming
    its inputs.
    """
    output: Any
    """The output of the runnable that generated the event.

    Outputs will only be available at the *END* of the runnable.

    For most runnables, this field can be inferred from the `chunk` field,
    though there might be some exceptions for special cased runnables (e.g., like
    chat models), which may return more information.
    """
    chunk: Any
    """A streaming chunk from the output that generated the event.

    Chunks support addition in general, and adding them up should result
    in the output of the runnable that generated the event.
    """


class StreamEvent(TypedDict):
    """A streaming event.

    Schema of a streaming event which is produced from the astream_events method.

    Example:

        .. code-block:: python

            from langchain_core.runnables import RunnableLambda

            async def reverse(s: str) -> str:
                return s[::-1]

            chain = RunnableLambda(func=reverse)

            events = [event async for event in chain.astream_events("hello")]

            # will produce the following events (run_id has been omitted for brevity):
            [
                {
                    "data": {"input": "hello"},
                    "event": "on_chain_start",
                    "metadata": {},
                    "name": "reverse",
                    "tags": [],
                },
                {
                    "data": {"chunk": "olleh"},
                    "event": "on_chain_stream",
                    "metadata": {},
                    "name": "reverse",
                    "tags": [],
                },
                {
                    "data": {"output": "olleh"},
                    "event": "on_chain_end",
                    "metadata": {},
                    "name": "reverse",
                    "tags": [],
                },
            ]
    """

    event: str
    """Event names are of the format: on_[runnable_type]_(start|stream|end).

    Runnable types are one of:
    * llm - used by non chat models
    * chat_model - used by chat models
    * prompt -- e.g., ChatPromptTemplate
    * tool -- from tools defined via @tool decorator or inheriting from Tool/BaseTool
    * chain - most Runnables are of this type

    Further, the events are categorized as one of:
    * start - when the runnable starts
    * stream - when the runnable is streaming
    * end - when the runnable ends

    start, stream and end are associated with slightly different `data` payloads.

    Please see the documentation for `EventData` for more details.
    """
    name: str
    """The name of the runnable that generated the event."""
    run_id: str
    """A randomly generated ID to keep track of the execution of the given runnable.

    Each child runnable that gets invoked as part of the execution of a parent
    runnable is assigned its own unique ID.
    """
    tags: NotRequired[List[str]]
    """Tags associated with the runnable that generated this event.

    Tags are always inherited from parent runnables.

    Tags can either be bound to a runnable using `.with_config({"tags": ["hello"]})`
    or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
    """
    metadata: NotRequired[Dict[str, Any]]
    """Metadata associated with the runnable that generated this event.

    Metadata can either be bound to a runnable using
    `.with_config({"metadata": { "foo": "bar" }})`
    or passed at run time using
    `.astream_events(..., {"metadata": {"foo": "bar"}})`.
    """
    data: EventData
    """Event data.

    The contents of the event data depend on the event type.
    """
```
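
To make the `chunk` contract above concrete, here is a small sketch (assumed usage, not part of the PR) checking that summed chunks reproduce the final output:

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    chain = RunnableLambda(lambda x: x + " world")
    chunks = []
    final_output = None
    async for event in chain.astream_events("hello"):
        if event["event"] == "on_chain_stream":
            chunks.append(event["data"]["chunk"])
        elif event["event"] == "on_chain_end":
            final_output = event["data"]["output"]
    # Chunks support addition; adding them up should equal the output.
    total = chunks[0]
    for chunk in chunks[1:]:
        total = total + chunk
    assert total == final_output == "hello world"


asyncio.run(main())
```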
```diff
@@ -1,3 +1,4 @@
+"""Utility code for runnables."""
 from __future__ import annotations
 
 import ast

@@ -24,6 +25,8 @@ from typing import (
     Union,
 )
 
+from langchain_core.runnables.schema import StreamEvent
+
 Input = TypeVar("Input", contravariant=True)
 # Output type should implement __concat__, as eg str, list, dict do
 Output = TypeVar("Output", covariant=True)

@@ -419,3 +422,58 @@ def get_unique_config_specs(
                 f"for {id}: {[first] + others}"
             )
     return unique
+
+
+class _RootEventFilter:
+    def __init__(
+        self,
+        *,
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+    ) -> None:
+        """Utility to filter the root event in the astream_events implementation.
+
+        This is simply binding the arguments to the namespace to save on
+        a bit of typing in the astream_events implementation.
+        """
+        self.include_names = include_names
+        self.include_types = include_types
+        self.include_tags = include_tags
+        self.exclude_names = exclude_names
+        self.exclude_types = exclude_types
+        self.exclude_tags = exclude_tags
+
+    def include_event(self, event: StreamEvent, root_type: str) -> bool:
+        """Determine whether to include an event."""
+        if (
+            self.include_names is None
+            and self.include_types is None
+            and self.include_tags is None
+        ):
+            include = True
+        else:
+            include = False
+
+        event_tags = event.get("tags") or []
+
+        if self.include_names is not None:
+            include = include or event["name"] in self.include_names
+        if self.include_types is not None:
+            include = include or root_type in self.include_types
+        if self.include_tags is not None:
+            include = include or any(tag in self.include_tags for tag in event_tags)
+
+        if self.exclude_names is not None:
+            include = include and event["name"] not in self.exclude_names
+        if self.exclude_types is not None:
+            include = include and root_type not in self.exclude_types
+        if self.exclude_tags is not None:
+            include = include and all(
+                tag not in self.exclude_tags for tag in event_tags
+            )
+
+        return include
```
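
A quick sketch of the filter's semantics (hand-built event dict for illustration; `_RootEventFilter` is a private API, assumed stable only within this PR):

```python
from langchain_core.runnables.utils import _RootEventFilter

event = {
    "event": "on_chain_start",
    "name": "my_chain",
    "run_id": "00000000-0000-0000-0000-000000000000",
    "tags": ["demo"],
    "metadata": {},
    "data": {"input": "hello"},
}

# With no include/exclude criteria, everything passes.
assert _RootEventFilter().include_event(event, root_type="chain")

# include_* act as an allow-list over names, types, and tags...
assert _RootEventFilter(include_tags=["demo"]).include_event(event, "chain")

# ...while exclude_* can veto an otherwise-included event.
assert not _RootEventFilter(
    include_tags=["demo"], exclude_names=["my_chain"]
).include_event(event, "chain")
```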
```diff
@@ -300,7 +300,7 @@ class ChildTool(BaseTool):
 
     def run(
         self,
-        tool_input: Union[str, Dict],
+        tool_input: Union[str, Dict[str, Any]],
         verbose: Optional[bool] = None,
         start_color: Optional[str] = "green",
         color: Optional[str] = "green",

@@ -333,6 +333,11 @@ class ChildTool(BaseTool):
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
             name=run_name,
+            # Inputs by definition should always be dicts.
+            # For now, it's unclear whether this assumption is ever violated,
+            # but if it is we will send a `None` value to the callback instead
+            # and will need to address the issue via a patch.
+            inputs=None if isinstance(tool_input, str) else tool_input,
             **kwargs,
         )
         try:

@@ -407,6 +412,7 @@ class ChildTool(BaseTool):
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
             name=run_name,
+            inputs=tool_input,
             **kwargs,
         )
         try:
```
@@ -11,8 +11,10 @@ from typing import (
     Any,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
+    Set,
     Union,
     cast,
 )
@@ -23,6 +25,7 @@ from tenacity import RetryCallState
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.exceptions import TracerException
 from langchain_core.load import dumpd
+from langchain_core.messages import BaseMessage
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -40,8 +43,29 @@ logger = logging.getLogger(__name__)
 class BaseTracer(BaseCallbackHandler, ABC):
     """Base interface for tracers."""

-    def __init__(self, **kwargs: Any) -> None:
+    def __init__(
+        self,
+        *,
+        _schema_format: Literal["original", "streaming_events"] = "original",
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the tracer.
+
+        Args:
+            _schema_format: Primarily changes how the inputs and outputs are
+                handled. For internal use only. This API will change.
+                - 'original' is the format used by all current tracers.
+                  This format is slightly inconsistent with respect to inputs
+                  and outputs.
+                - 'streaming_events' is used for supporting streaming events,
+                  for internal usage. It will likely change in the future, or
+                  be deprecated entirely in favor of a dedicated async tracer
+                  for streaming events.
+            kwargs: Additional keyword arguments that will be passed to
+                the super class.
+        """
         super().__init__(**kwargs)
+        self._schema_format = _schema_format  # For internal use only. API will change.
         self.run_map: Dict[str, Run] = {}

     @staticmethod
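As a hedged illustration of the new keyword (the subclass below is not code from this PR): a tracer opts into the experimental schema by forwarding the flag to `BaseTracer.__init__`, while `"original"` remains the default for all existing tracers.

```python
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run


class StreamingSchemaTracer(BaseTracer):
    """Illustrative subclass opting into the new streaming_events schema."""

    def __init__(self) -> None:
        # "original" stays the default; only internal tracers opt in.
        super().__init__(_schema_format="streaming_events")

    def _persist_run(self, run: Run) -> None:
        # BaseTracer's abstract hook; a real tracer would store the run.
        pass
```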
@@ -134,17 +158,76 @@ class BaseTracer(BaseCallbackHandler, ABC):

         return parent_run.child_execution_order + 1

-    def _get_run(self, run_id: UUID, run_type: Optional[str] = None) -> Run:
+    def _get_run(
+        self, run_id: UUID, run_type: Union[str, Set[str], None] = None
+    ) -> Run:
         try:
             run = self.run_map[str(run_id)]
         except KeyError as exc:
             raise TracerException(f"No indexed run ID {run_id}.") from exc
-        if run_type is not None and run.run_type != run_type:
+
+        if isinstance(run_type, str):
+            run_types: Union[Set[str], None] = {run_type}
+        else:
+            run_types = run_type
+        if run_types is not None and run.run_type not in run_types:
             raise TracerException(
-                f"Found {run.run_type} run at ID {run_id}, but expected {run_type} run."
+                f"Found {run.run_type} run at ID {run_id}, "
+                f"but expected {run_types} run."
             )
         return run

+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        tags: Optional[List[str]] = None,
+        parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Run:
+        """Start a trace for an LLM run."""
+        if self._schema_format != "streaming_events":
+            # Please keep this un-implemented for backwards compatibility.
+            # When it's unimplemented, old tracers that use the "original" format
+            # fall back on the on_llm_start method implementation if they
+            # find that the on_chat_model_start method is not implemented.
+            # This can eventually be cleaned up by writing a "modern" tracer
+            # that has all the updated schema changes corresponding to
+            # the "streaming_events" format.
+            raise NotImplementedError(
+                f"Chat model tracing is not supported for "
+                f"the {self._schema_format} format."
+            )
+        parent_run_id_ = str(parent_run_id) if parent_run_id else None
+        execution_order = self._get_execution_order(parent_run_id_)
+        start_time = datetime.now(timezone.utc)
+        if metadata:
+            kwargs.update({"metadata": metadata})
+        chat_model_run = Run(
+            id=run_id,
+            parent_run_id=parent_run_id,
+            serialized=serialized,
+            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
+            extra=kwargs,
+            events=[{"name": "start", "time": start_time}],
+            start_time=start_time,
+            execution_order=execution_order,
+            child_execution_order=execution_order,
+            # WARNING: This is valid ONLY for streaming_events.
+            # run_type="llm" is what's used by virtually all tracers.
+            # Changing this to "chat_model" may break triggering on_llm_start.
+            run_type="chat_model",
+            tags=tags,
+            name=name,
+        )
+        self._start_trace(chat_model_run)
+        self._on_chat_model_start(chat_model_run)
+        return chat_model_run
+
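The reason `_get_run` now accepts `Union[str, Set[str], None]` is visible in the LLM lifecycle callbacks below: a chat-model run started via `on_chat_model_start` is stored with `run_type="chat_model"`, while completion LLMs keep `run_type="llm"`, so token/end/error lookups must accept either. A minimal standalone sketch of the widening (not the class method itself):

```python
from typing import Optional, Set, Union


def normalize_run_types(run_type: Union[str, Set[str], None]) -> Optional[Set[str]]:
    # A bare string is promoted to a one-element set, mirroring _get_run above.
    if isinstance(run_type, str):
        return {run_type}
    return run_type


assert normalize_run_types("llm") == {"llm"}
assert normalize_run_types({"llm", "chat_model"}) == {"llm", "chat_model"}
assert normalize_run_types(None) is None
```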
     def on_llm_start(
         self,
         serialized: Dict[str, Any],
@@ -167,6 +250,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
+            # TODO: Figure out how to expose kwargs here
             inputs={"prompts": prompts},
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
@@ -191,7 +275,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
         **kwargs: Any,
     ) -> Run:
         """Run on new LLM token. Only available when streaming is enabled."""
-        llm_run = self._get_run(run_id, run_type="llm")
+        # "chat_model" is only used for the experimental new streaming_events format.
+        # This change should not affect any existing tracers.
+        llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
         event_kwargs: Dict[str, Any] = {"token": token}
         if chunk:
             event_kwargs["chunk"] = chunk
@@ -238,7 +324,9 @@ class BaseTracer(BaseCallbackHandler, ABC):

     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
         """End a trace for an LLM run."""
-        llm_run = self._get_run(run_id, run_type="llm")
+        # "chat_model" is only used for the experimental new streaming_events format.
+        # This change should not affect any existing tracers.
+        llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
         llm_run.outputs = response.dict()
         for i, generations in enumerate(response.generations):
             for j, generation in enumerate(generations):
@@ -261,7 +349,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
         **kwargs: Any,
     ) -> Run:
         """Handle an error for an LLM run."""
-        llm_run = self._get_run(run_id, run_type="llm")
+        # "chat_model" is only used for the experimental new streaming_events format.
+        # This change should not affect any existing tracers.
+        llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
         llm_run.error = self._get_stacktrace(error)
         llm_run.end_time = datetime.now(timezone.utc)
         llm_run.events.append({"name": "error", "time": llm_run.end_time})
@@ -292,7 +382,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
-            inputs=inputs if isinstance(inputs, dict) else {"input": inputs},
+            inputs=self._get_chain_inputs(inputs),
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
             start_time=start_time,
@@ -307,6 +397,28 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self._on_chain_start(chain_run)
         return chain_run

+    def _get_chain_inputs(self, inputs: Any) -> Any:
+        """Get the inputs for a chain run."""
+        if self._schema_format == "original":
+            return inputs if isinstance(inputs, dict) else {"input": inputs}
+        elif self._schema_format == "streaming_events":
+            return {
+                "input": inputs,
+            }
+        else:
+            raise ValueError(f"Invalid format: {self._schema_format}")
+
+    def _get_chain_outputs(self, outputs: Any) -> Any:
+        """Get the outputs for a chain run."""
+        if self._schema_format == "original":
+            return outputs if isinstance(outputs, dict) else {"output": outputs}
+        elif self._schema_format == "streaming_events":
+            return {
+                "output": outputs,
+            }
+        else:
+            raise ValueError(f"Invalid format: {self._schema_format}")
+
     def on_chain_end(
         self,
         outputs: Dict[str, Any],
@@ -317,13 +429,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
     ) -> Run:
         """End a trace for a chain run."""
         chain_run = self._get_run(run_id)
-        chain_run.outputs = (
-            outputs if isinstance(outputs, dict) else {"output": outputs}
-        )
+        chain_run.outputs = self._get_chain_outputs(outputs)
         chain_run.end_time = datetime.now(timezone.utc)
         chain_run.events.append({"name": "end", "time": chain_run.end_time})
         if inputs is not None:
-            chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
+            chain_run.inputs = self._get_chain_inputs(inputs)
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)
         return chain_run
@@ -342,7 +452,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         chain_run.end_time = datetime.now(timezone.utc)
         chain_run.events.append({"name": "error", "time": chain_run.end_time})
         if inputs is not None:
-            chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
+            chain_run.inputs = self._get_chain_inputs(inputs)
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)
         return chain_run
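To make the two schema formats concrete, here is a hedged sketch of the normalization `_get_chain_inputs` performs; the standalone function mirrors the method above for illustration only.

```python
def chain_inputs(inputs, schema_format: str):
    """Mirror of _get_chain_inputs from this diff, for illustration only."""
    if schema_format == "original":
        # Only wrap when the input isn't already a dict.
        return inputs if isinstance(inputs, dict) else {"input": inputs}
    elif schema_format == "streaming_events":
        # Always wrap, so inputs are uniformly {"input": ...}.
        return {"input": inputs}
    raise ValueError(f"Invalid format: {schema_format}")


assert chain_inputs("hello", "original") == {"input": "hello"}
assert chain_inputs({"x": 1}, "original") == {"x": 1}  # left as-is
assert chain_inputs({"x": 1}, "streaming_events") == {"input": {"x": 1}}  # nested
```

The design point: `"original"` leaves dict inputs untouched (hence the noted inconsistency), while `"streaming_events"` always nests under an `"input"` key so consumers can rely on one shape.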
@@ -357,6 +467,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
         name: Optional[str] = None,
+        inputs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for a tool run."""
@@ -365,11 +476,20 @@ class BaseTracer(BaseCallbackHandler, ABC):
         start_time = datetime.now(timezone.utc)
         if metadata:
             kwargs.update({"metadata": metadata})
+
+        if self._schema_format == "original":
+            inputs = {"input": input_str}
+        elif self._schema_format == "streaming_events":
+            inputs = {"input": inputs}
+        else:
+            raise AssertionError(f"Invalid format: {self._schema_format}")
+
         tool_run = Run(
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
-            inputs={"input": input_str},
+            # Wrapping in dict since Run requires a dict object.
+            inputs=inputs,
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
             start_time=start_time,
@@ -112,7 +112,7 @@ class LangChainTracer(BaseTracer):
         metadata: Optional[Dict[str, Any]] = None,
         name: Optional[str] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Start a trace for an LLM run."""
         parent_run_id_ = str(parent_run_id) if parent_run_id else None
         execution_order = self._get_execution_order(parent_run_id_)
@@ -135,6 +135,7 @@ class LangChainTracer(BaseTracer):
         )
         self._start_trace(chat_model_run)
         self._on_chat_model_start(chat_model_run)
+        return chat_model_run

     def _persist_run(self, run: Run) -> None:
         run_ = run.copy()
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import copy
 import math
 import threading
@@ -9,19 +10,24 @@ from typing import (
     AsyncIterator,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
-    TypedDict,
     TypeVar,
     Union,
+    overload,
 )
 from uuid import UUID

 import jsonpatch  # type: ignore[import]
 from anyio import create_memory_object_stream
+from typing_extensions import NotRequired, TypedDict

+from langchain_core.load import dumps
 from langchain_core.load.load import load
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
+from langchain_core.runnables.utils import Input, Output
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.schemas import Run
@@ -46,8 +52,11 @@ class LogEntry(TypedDict):
     """List of LLM tokens streamed by this run, if applicable."""
     streamed_output: List[Any]
     """List of output chunks streamed by this run, if available."""
+    inputs: NotRequired[Optional[Any]]
+    """Inputs to this run. Not available currently via astream_log."""
     final_output: Optional[Any]
     """Final output of this run.

     Only available after the run has finished successfully."""
     end_time: Optional[str]
     """ISO-8601 timestamp of when the run ended.
@@ -65,6 +74,14 @@ class RunState(TypedDict):
     """Final output of the run, usually the result of aggregating (`+`) streamed_output.
     Updated throughout the run when supported by the Runnable."""

+    name: str
+    """Name of the object being run."""
+    type: str
+    """Type of the object being run, e.g. prompt, chain, llm, etc."""
+
+    # Do we want tags/metadata on the root run? Client kinda knows it in most situations
+    # tags: List[str]
+
     logs: Dict[str, LogEntry]
     """Map of run names to sub-runs. If filters were supplied, this list will
     contain only the runs that matched the filters."""
@@ -128,6 +145,15 @@ class RunLog(RunLogPatch):

         return f"RunLog({pformat(self.state)})"

+    def __eq__(self, other: object) -> bool:
+        # First compare that the state is the same
+        if not isinstance(other, RunLog):
+            return False
+        if self.state != other.state:
+            return False
+        # Then compare that the ops are the same
+        return super().__eq__(other)
+

 T = TypeVar("T")
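A hedged sketch of what the new `__eq__` implies: two `RunLog`s compare equal only when both their aggregated `state` and their op history match. Constructor usage below is inferred from `RunLog(state=None)` and `RunLogPatch(*ops)` elsewhere in this diff; the values are illustrative.

```python
from langchain_core.tracers.log_stream import RunLog

op = {"op": "add", "path": "/streamed_output/-", "value": "hi"}
a = RunLog(op, state={"streamed_output": ["hi"]})
b = RunLog(op, state={"streamed_output": ["hi"]})
c = RunLog(op, state={"streamed_output": []})

assert a == b  # same state and same ops
assert a != c  # same ops but diverging state -> not equal
```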
@@ -145,8 +171,36 @@ class LogStreamCallbackHandler(BaseTracer):
         exclude_names: Optional[Sequence[str]] = None,
         exclude_types: Optional[Sequence[str]] = None,
         exclude_tags: Optional[Sequence[str]] = None,
+        # Schema format is for internal use only.
+        _schema_format: Literal["original", "streaming_events"] = "streaming_events",
     ) -> None:
-        super().__init__()
+        """A tracer that streams run logs to a stream.
+
+        Args:
+            auto_close: Whether to close the stream when the root run finishes.
+            include_names: Only include runs from Runnables with matching names.
+            include_types: Only include runs from Runnables with matching types.
+            include_tags: Only include runs from Runnables with matching tags.
+            exclude_names: Exclude runs from Runnables with matching names.
+            exclude_types: Exclude runs from Runnables with matching types.
+            exclude_tags: Exclude runs from Runnables with matching tags.
+            _schema_format: Primarily changes how the inputs and outputs are
+                handled.
+                **For internal use only. This API will change.**
+                - 'original' is the format used by all current tracers.
+                  This format is slightly inconsistent with respect to inputs
+                  and outputs.
+                - 'streaming_events' is used for supporting streaming events,
+                  for internal usage. It will likely change in the future, or
+                  be deprecated entirely in favor of a dedicated async tracer
+                  for streaming events.
+        """
+        if _schema_format not in {"original", "streaming_events"}:
+            raise ValueError(
+                f"Invalid schema format: {_schema_format}. "
+                f"Expected one of 'original', 'streaming_events'."
+            )
+        super().__init__(_schema_format=_schema_format)
+
         self.auto_close = auto_close
         self.include_names = include_names
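For orientation, constructing the handler with a couple of the filters documented above might look like this (the run names and tags are made up; only the keyword names come from the signature in this diff):

```python
from langchain_core.tracers.log_stream import LogStreamCallbackHandler

# Illustrative: only stream logs for runs named "my_llm",
# skipping anything tagged "internal".
handler = LogStreamCallbackHandler(
    include_names=["my_llm"],
    exclude_tags=["internal"],
)
```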
@@ -241,6 +295,8 @@ class LogStreamCallbackHandler(BaseTracer):
                     streamed_output=[],
                     final_output=None,
                     logs={},
+                    name=run.name,
+                    type=run.run_type,
                 ),
             }
         )
@@ -257,13 +313,7 @@ class LogStreamCallbackHandler(BaseTracer):
             run.name if count == 1 else f"{run.name}:{count}"
         )

-        # Add the run to the stream
-        self.send_stream.send_nowait(
-            RunLogPatch(
-                {
-                    "op": "add",
-                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
-                    "value": LogEntry(
+        entry = LogEntry(
             id=str(run.id),
             name=run.name,
             type=run.run_type,
@@ -274,7 +324,19 @@ class LogStreamCallbackHandler(BaseTracer):
             streamed_output_str=[],
             final_output=None,
             end_time=None,
-        ),
+        )
+
+        if self._schema_format == "streaming_events":
+            # If using streaming events, let's add inputs as well
+            entry["inputs"] = _get_standardized_inputs(run, self._schema_format)
+
+        # Add the run to the stream
+        self.send_stream.send_nowait(
+            RunLogPatch(
+                {
+                    "op": "add",
+                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
+                    "value": entry,
                 }
             )
         )
@@ -287,13 +349,28 @@ class LogStreamCallbackHandler(BaseTracer):
             if index is None:
                 return

-            self.send_stream.send_nowait(
-                RunLogPatch(
+            ops = []
+            if self._schema_format == "streaming_events":
+                ops.append(
+                    {
+                        "op": "replace",
+                        "path": f"/logs/{index}/inputs",
+                        "value": _get_standardized_inputs(run, self._schema_format),
+                    }
+                )
+
+            ops.extend(
+                [
+                    # Replace 'inputs' with final inputs
+                    # This is needed because in many cases the inputs are not
+                    # known until after the run is finished and the entire
+                    # input stream has been processed by the runnable.
                     {
                         "op": "add",
                         "path": f"/logs/{index}/final_output",
                         # to undo the dumpd done by some runnables / tracer / etc
-                        "value": load(run.outputs),
+                        "value": _get_standardized_outputs(run, self._schema_format),
                     },
                     {
                         "op": "add",
@@ -302,8 +379,10 @@ class LogStreamCallbackHandler(BaseTracer):
                         if run.end_time is not None
                         else None,
                     },
+                ]
             )
-            )
+            self.send_stream.send_nowait(RunLogPatch(*ops))
         finally:
             if run.id == self.root_id:
                 if self.auto_close:
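For orientation, the single `RunLogPatch` batch sent when a sub-run finishes under the streaming_events format would look roughly like this; the paths and values are examples only, not output captured from this PR.

```python
# Illustrative shape of the ops batched into one RunLogPatch on run update.
ops = [
    {"op": "replace", "path": "/logs/my_chain/inputs", "value": {"input": "hi"}},
    {"op": "add", "path": "/logs/my_chain/final_output", "value": {"output": "HI"}},
    {"op": "add", "path": "/logs/my_chain/end_time", "value": "2024-01-18T00:00:00.000"},
]
```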
@@ -337,3 +416,197 @@ class LogStreamCallbackHandler(BaseTracer):
                     },
                 )
             )
+
+
+def _get_standardized_inputs(
+    run: Run, schema_format: Literal["original", "streaming_events"]
+) -> Optional[Dict[str, Any]]:
+    """Extract standardized inputs from a run.
+
+    Standardizes the inputs based on the type of the runnable used.
+
+    Args:
+        run: Run object
+        schema_format: The schema format to use.
+
+    Returns:
+        Valid inputs are only dicts. By convention, inputs always represent an
+        invocation using named arguments.
+        A None means that the input is not yet known!
+    """
+    if schema_format == "original":
+        raise NotImplementedError(
+            "Do not assign inputs with the original schema; drop the key for now. "
+            "When inputs are added to astream_log they should be added with a "
+            "standardized schema for streaming events."
+        )
+
+    inputs = load(run.inputs)
+
+    if run.run_type in {"retriever", "llm", "chat_model"}:
+        return inputs
+
+    # New-style chains nest an additional 'input' key inside the 'inputs' to make
+    # sure the input is always a dict. We need to unpack and use the inner value.
+    inputs = inputs["input"]
+    # We should try to fix this in Runnables and callbacks/tracers:
+    # Runnables should be using a None type here, not a placeholder dict.
+    if inputs == {"input": ""}:  # Workaround for Runnables not using None
+        # The input is not known, so we don't assign data['input']
+        return None
+    return inputs
+
+
+def _get_standardized_outputs(
+    run: Run, schema_format: Literal["original", "streaming_events"]
+) -> Optional[Any]:
+    """Extract standardized output from a run.
+
+    Standardizes the outputs based on the type of the runnable used.
+
+    Args:
+        run: Run object
+        schema_format: The schema format to use.
+
+    Returns:
+        The output if one was returned, otherwise None.
+    """
+    outputs = load(run.outputs)
+    if schema_format == "original":
+        # Return the old schema, without standardizing anything
+        return outputs
+
+    if run.run_type in {"retriever", "llm", "chat_model"}:
+        return outputs
+
+    if isinstance(outputs, dict):
+        return outputs.get("output", None)
+
+    return None
+
+
+@overload
+def _astream_log_implementation(
+    runnable: Runnable[Input, Output],
+    input: Any,
+    config: Optional[RunnableConfig] = None,
+    *,
+    stream: LogStreamCallbackHandler,
+    diff: Literal[True] = True,
+    with_streamed_output_list: bool = True,
+    **kwargs: Any,
+) -> AsyncIterator[RunLogPatch]:
+    ...
+
+
+@overload
+def _astream_log_implementation(
+    runnable: Runnable[Input, Output],
+    input: Any,
+    config: Optional[RunnableConfig] = None,
+    *,
+    stream: LogStreamCallbackHandler,
+    diff: Literal[False],
+    with_streamed_output_list: bool = True,
+    **kwargs: Any,
+) -> AsyncIterator[RunLog]:
+    ...
+
+
+async def _astream_log_implementation(
+    runnable: Runnable[Input, Output],
+    input: Any,
+    config: Optional[RunnableConfig] = None,
+    *,
+    stream: LogStreamCallbackHandler,
+    diff: bool = True,
+    with_streamed_output_list: bool = True,
+    **kwargs: Any,
+) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
+    """Implementation of astream_log for a given runnable.
+
+    The implementation has been factored out (at least temporarily) as both
+    astream_log and astream_events rely on it.
+    """
+    import jsonpatch  # type: ignore[import]
+
+    from langchain_core.callbacks.base import BaseCallbackManager
+    from langchain_core.tracers.log_stream import (
+        RunLog,
+        RunLogPatch,
+    )
+
+    # Assign the stream handler to the config
+    config = ensure_config(config)
+    callbacks = config.get("callbacks")
+    if callbacks is None:
+        config["callbacks"] = [stream]
+    elif isinstance(callbacks, list):
+        config["callbacks"] = callbacks + [stream]
+    elif isinstance(callbacks, BaseCallbackManager):
+        callbacks = callbacks.copy()
+        callbacks.add_handler(stream, inherit=True)
+        config["callbacks"] = callbacks
+    else:
+        raise ValueError(
+            f"Unexpected type for callbacks: {callbacks}. "
+            "Expected None, list or AsyncCallbackManager."
+        )
+
+    # Call the runnable in streaming mode,
+    # add each chunk to the output stream
+    async def consume_astream() -> None:
+        try:
+            prev_final_output: Optional[Output] = None
+            final_output: Optional[Output] = None
+
+            async for chunk in runnable.astream(input, config, **kwargs):
+                prev_final_output = final_output
+                if final_output is None:
+                    final_output = chunk
+                else:
+                    try:
+                        final_output = final_output + chunk  # type: ignore
+                    except TypeError:
+                        final_output = chunk
+                patches: List[Dict[str, Any]] = []
+                if with_streamed_output_list:
+                    patches.append(
+                        {
+                            "op": "add",
+                            "path": "/streamed_output/-",
+                            # chunk cannot be shared between
+                            # streamed_output and final_output
+                            # otherwise jsonpatch.apply will
+                            # modify both
+                            "value": copy.deepcopy(chunk),
+                        }
+                    )
+                for op in jsonpatch.JsonPatch.from_diff(
+                    prev_final_output, final_output, dumps=dumps
+                ):
+                    patches.append({**op, "path": f"/final_output{op['path']}"})
+                await stream.send_stream.send(RunLogPatch(*patches))
+        finally:
+            await stream.send_stream.aclose()
+
+    # Start the runnable in a task, so we can start consuming output
+    task = asyncio.create_task(consume_astream())
+    try:
+        # Yield each chunk from the output stream
+        if diff:
+            async for log in stream:
+                yield log
+        else:
+            state = RunLog(state=None)  # type: ignore[arg-type]
+            async for log in stream:
+                state = state + log
+                yield state
+    finally:
+        # Wait for the runnable to finish, if not cancelled (eg. by break)
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
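Since `_astream_log_implementation` is shared plumbing behind both public methods, its behavior is easiest to see from a caller. A hedged usage sketch (the chain is made up; `astream_log` itself predates this PR): with the default `diff=True` it yields incremental `RunLogPatch` objects, while `diff=False` would yield cumulative `RunLog` states instead.

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    chain = RunnableLambda(lambda x: x + 1)
    # Each patch is a batch of jsonpatch ops against the RunState.
    async for patch in chain.astream_log(1):
        print(patch)


asyncio.run(main())
```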
@@ -1647,6 +1647,8 @@ async def test_prompt() -> None:
                 ]
             )
         ],
+        "type": "prompt",
+        "name": "ChatPromptTemplate",
     },
 )
@@ -2095,6 +2097,8 @@ async def test_prompt_with_llm(
                 "logs": {},
                 "final_output": None,
                 "streamed_output": [],
+                "name": "RunnableSequence",
+                "type": "chain",
             },
         }
     ),
@@ -2297,6 +2301,8 @@ async def test_prompt_with_llm_parser(
                 "logs": {},
                 "final_output": None,
                 "streamed_output": [],
+                "name": "RunnableSequence",
+                "type": "chain",
             },
         }
     ),
@@ -2508,7 +2514,13 @@ async def test_stream_log_lists() -> None:
             {
                 "op": "replace",
                 "path": "",
-                "value": {"final_output": None, "logs": {}, "streamed_output": []},
+                "value": {
+                    "final_output": None,
+                    "logs": {},
+                    "streamed_output": [],
+                    "name": "list_producer",
+                    "type": "chain",
+                },
             }
         ),
         RunLogPatch(
@@ -2536,12 +2548,14 @@ async def test_stream_log_lists() -> None:
     assert state.state == {
         "final_output": {"alist": ["0", "1", "2", "3"]},
         "logs": {},
+        "name": "list_producer",
         "streamed_output": [
             {"alist": ["0"]},
             {"alist": ["1"]},
             {"alist": ["2"]},
             {"alist": ["3"]},
         ],
+        "type": "chain",
     }
@@ -5139,4 +5153,6 @@ async def test_astream_log_deep_copies() -> None:
         "final_output": 2,
         "logs": {},
         "streamed_output": [2],
+        "name": "add_one",
+        "type": "chain",
     }
1065 libs/core/tests/unit_tests/runnables/test_runnable_events.py Normal file
File diff suppressed because it is too large