mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Add Runnable.astream_log()
(#10374)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
a1ade48e8f
commit
fcb5aba9f0
@ -1,7 +1,7 @@
|
|||||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from tenacity import RetryCallState
|
from tenacity import RetryCallState
|
||||||
@ -502,6 +502,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run on retriever error."""
|
"""Run on retriever error."""
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="BaseCallbackManager")
|
||||||
|
|
||||||
|
|
||||||
class BaseCallbackManager(CallbackManagerMixin):
|
class BaseCallbackManager(CallbackManagerMixin):
|
||||||
"""Base callback manager that handles callbacks from LangChain."""
|
"""Base callback manager that handles callbacks from LangChain."""
|
||||||
|
|
||||||
@ -527,6 +530,18 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
self.metadata = metadata or {}
|
self.metadata = metadata or {}
|
||||||
self.inheritable_metadata = inheritable_metadata or {}
|
self.inheritable_metadata = inheritable_metadata or {}
|
||||||
|
|
||||||
|
def copy(self: T) -> T:
|
||||||
|
"""Copy the callback manager."""
|
||||||
|
return self.__class__(
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_async(self) -> bool:
|
def is_async(self) -> bool:
|
||||||
"""Whether the callback manager is async."""
|
"""Whether the callback manager is async."""
|
||||||
|
@ -58,6 +58,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
|
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
|
||||||
self.run_map[str(run.id)] = run
|
self.run_map[str(run.id)] = run
|
||||||
|
self._on_run_create(run)
|
||||||
|
|
||||||
def _end_trace(self, run: Run) -> None:
|
def _end_trace(self, run: Run) -> None:
|
||||||
"""End a trace for a run."""
|
"""End a trace for a run."""
|
||||||
@ -74,6 +75,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
):
|
):
|
||||||
parent_run.child_execution_order = run.child_execution_order
|
parent_run.child_execution_order = run.child_execution_order
|
||||||
self.run_map.pop(str(run.id))
|
self.run_map.pop(str(run.id))
|
||||||
|
self._on_run_update(run)
|
||||||
|
|
||||||
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
||||||
"""Get the execution order for a run."""
|
"""Get the execution order for a run."""
|
||||||
@ -101,7 +103,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Start a trace for an LLM run."""
|
"""Start a trace for an LLM run."""
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
@ -123,6 +125,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
)
|
)
|
||||||
self._start_trace(llm_run)
|
self._start_trace(llm_run)
|
||||||
self._on_llm_start(llm_run)
|
self._on_llm_start(llm_run)
|
||||||
|
return llm_run
|
||||||
|
|
||||||
def on_llm_new_token(
|
def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
@ -132,7 +135,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_llm_new_token callback.")
|
raise TracerException("No run_id provided for on_llm_new_token callback.")
|
||||||
@ -151,6 +154,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
"kwargs": event_kwargs,
|
"kwargs": event_kwargs,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self._on_llm_new_token(llm_run, token, chunk)
|
||||||
|
return llm_run
|
||||||
|
|
||||||
def on_retry(
|
def on_retry(
|
||||||
self,
|
self,
|
||||||
@ -158,7 +163,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_retry callback.")
|
raise TracerException("No run_id provided for on_retry callback.")
|
||||||
run_id_ = str(run_id)
|
run_id_ = str(run_id)
|
||||||
@ -186,8 +191,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
"kwargs": retry_d,
|
"kwargs": retry_d,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return llm_run
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for an LLM run."""
|
"""End a trace for an LLM run."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_llm_end callback.")
|
raise TracerException("No run_id provided for on_llm_end callback.")
|
||||||
@ -208,6 +214,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
llm_run.events.append({"name": "end", "time": llm_run.end_time})
|
llm_run.events.append({"name": "end", "time": llm_run.end_time})
|
||||||
self._end_trace(llm_run)
|
self._end_trace(llm_run)
|
||||||
self._on_llm_end(llm_run)
|
self._on_llm_end(llm_run)
|
||||||
|
return llm_run
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
@ -215,7 +222,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Handle an error for an LLM run."""
|
"""Handle an error for an LLM run."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_llm_error callback.")
|
raise TracerException("No run_id provided for on_llm_error callback.")
|
||||||
@ -229,6 +236,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||||
self._end_trace(llm_run)
|
self._end_trace(llm_run)
|
||||||
self._on_chain_error(llm_run)
|
self._on_chain_error(llm_run)
|
||||||
|
return llm_run
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
@ -242,7 +250,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Start a trace for a chain run."""
|
"""Start a trace for a chain run."""
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
@ -266,6 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
)
|
)
|
||||||
self._start_trace(chain_run)
|
self._start_trace(chain_run)
|
||||||
self._on_chain_start(chain_run)
|
self._on_chain_start(chain_run)
|
||||||
|
return chain_run
|
||||||
|
|
||||||
def on_chain_end(
|
def on_chain_end(
|
||||||
self,
|
self,
|
||||||
@ -274,7 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
inputs: Optional[Dict[str, Any]] = None,
|
inputs: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""End a trace for a chain run."""
|
"""End a trace for a chain run."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_chain_end callback.")
|
raise TracerException("No run_id provided for on_chain_end callback.")
|
||||||
@ -291,6 +300,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
self._on_chain_end(chain_run)
|
self._on_chain_end(chain_run)
|
||||||
|
return chain_run
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
@ -299,7 +309,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
inputs: Optional[Dict[str, Any]] = None,
|
inputs: Optional[Dict[str, Any]] = None,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Handle an error for a chain run."""
|
"""Handle an error for a chain run."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_chain_error callback.")
|
raise TracerException("No run_id provided for on_chain_error callback.")
|
||||||
@ -314,6 +324,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
self._on_chain_error(chain_run)
|
self._on_chain_error(chain_run)
|
||||||
|
return chain_run
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
@ -325,7 +336,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Start a trace for a tool run."""
|
"""Start a trace for a tool run."""
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
@ -348,8 +359,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
)
|
)
|
||||||
self._start_trace(tool_run)
|
self._start_trace(tool_run)
|
||||||
self._on_tool_start(tool_run)
|
self._on_tool_start(tool_run)
|
||||||
|
return tool_run
|
||||||
|
|
||||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for a tool run."""
|
"""End a trace for a tool run."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_tool_end callback.")
|
raise TracerException("No run_id provided for on_tool_end callback.")
|
||||||
@ -362,6 +374,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
||||||
self._end_trace(tool_run)
|
self._end_trace(tool_run)
|
||||||
self._on_tool_end(tool_run)
|
self._on_tool_end(tool_run)
|
||||||
|
return tool_run
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
@ -369,7 +382,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Handle an error for a tool run."""
|
"""Handle an error for a tool run."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_tool_error callback.")
|
raise TracerException("No run_id provided for on_tool_error callback.")
|
||||||
@ -382,6 +395,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
||||||
self._end_trace(tool_run)
|
self._end_trace(tool_run)
|
||||||
self._on_tool_error(tool_run)
|
self._on_tool_error(tool_run)
|
||||||
|
return tool_run
|
||||||
|
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
@ -393,7 +407,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Run when Retriever starts running."""
|
"""Run when Retriever starts running."""
|
||||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||||
execution_order = self._get_execution_order(parent_run_id_)
|
execution_order = self._get_execution_order(parent_run_id_)
|
||||||
@ -417,6 +431,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
)
|
)
|
||||||
self._start_trace(retrieval_run)
|
self._start_trace(retrieval_run)
|
||||||
self._on_retriever_start(retrieval_run)
|
self._on_retriever_start(retrieval_run)
|
||||||
|
return retrieval_run
|
||||||
|
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
@ -424,7 +439,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Run when Retriever errors."""
|
"""Run when Retriever errors."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_retriever_error callback.")
|
raise TracerException("No run_id provided for on_retriever_error callback.")
|
||||||
@ -437,10 +452,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
||||||
self._end_trace(retrieval_run)
|
self._end_trace(retrieval_run)
|
||||||
self._on_retriever_error(retrieval_run)
|
self._on_retriever_error(retrieval_run)
|
||||||
|
return retrieval_run
|
||||||
|
|
||||||
def on_retriever_end(
|
def on_retriever_end(
|
||||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||||
) -> None:
|
) -> Run:
|
||||||
"""Run when Retriever ends running."""
|
"""Run when Retriever ends running."""
|
||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_retriever_end callback.")
|
raise TracerException("No run_id provided for on_retriever_end callback.")
|
||||||
@ -452,6 +468,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
||||||
self._end_trace(retrieval_run)
|
self._end_trace(retrieval_run)
|
||||||
self._on_retriever_end(retrieval_run)
|
self._on_retriever_end(retrieval_run)
|
||||||
|
return retrieval_run
|
||||||
|
|
||||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||||
"""Deepcopy the tracer."""
|
"""Deepcopy the tracer."""
|
||||||
@ -461,9 +478,23 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
"""Copy the tracer."""
|
"""Copy the tracer."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _on_run_create(self, run: Run) -> None:
|
||||||
|
"""Process a run upon creation."""
|
||||||
|
|
||||||
|
def _on_run_update(self, run: Run) -> None:
|
||||||
|
"""Process a run upon update."""
|
||||||
|
|
||||||
def _on_llm_start(self, run: Run) -> None:
|
def _on_llm_start(self, run: Run) -> None:
|
||||||
"""Process the LLM Run upon start."""
|
"""Process the LLM Run upon start."""
|
||||||
|
|
||||||
|
def _on_llm_new_token(
|
||||||
|
self,
|
||||||
|
run: Run,
|
||||||
|
token: str,
|
||||||
|
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||||
|
) -> None:
|
||||||
|
"""Process new LLM token."""
|
||||||
|
|
||||||
def _on_llm_end(self, run: Run) -> None:
|
def _on_llm_end(self, run: Run) -> None:
|
||||||
"""Process the LLM Run."""
|
"""Process the LLM Run."""
|
||||||
|
|
||||||
|
289
libs/langchain/langchain/callbacks/tracers/log_stream.py
Normal file
289
libs/langchain/langchain/callbacks/tracers/log_stream.py
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
TypedDict,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import jsonpatch
|
||||||
|
from anyio import create_memory_object_stream
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
|
||||||
|
|
||||||
|
|
||||||
|
class LogEntry(TypedDict):
|
||||||
|
id: str
|
||||||
|
"""ID of the sub-run."""
|
||||||
|
name: str
|
||||||
|
"""Name of the object being run."""
|
||||||
|
type: str
|
||||||
|
"""Type of the object being run, eg. prompt, chain, llm, etc."""
|
||||||
|
tags: List[str]
|
||||||
|
"""List of tags for the run."""
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
"""Key-value pairs of metadata for the run."""
|
||||||
|
start_time: str
|
||||||
|
"""ISO-8601 timestamp of when the run started."""
|
||||||
|
|
||||||
|
streamed_output_str: List[str]
|
||||||
|
"""List of LLM tokens streamed by this run, if applicable."""
|
||||||
|
final_output: Optional[Any]
|
||||||
|
"""Final output of this run.
|
||||||
|
Only available after the run has finished successfully."""
|
||||||
|
end_time: Optional[str]
|
||||||
|
"""ISO-8601 timestamp of when the run ended.
|
||||||
|
Only available after the run has finished."""
|
||||||
|
|
||||||
|
|
||||||
|
class RunState(TypedDict):
|
||||||
|
id: str
|
||||||
|
"""ID of the run."""
|
||||||
|
streamed_output: List[Any]
|
||||||
|
"""List of output chunks streamed by Runnable.stream()"""
|
||||||
|
final_output: Optional[Any]
|
||||||
|
"""Final output of the run, usually the result of aggregating streamed_output.
|
||||||
|
Only available after the run has finished successfully."""
|
||||||
|
|
||||||
|
logs: list[LogEntry]
|
||||||
|
"""List of sub-runs contained in this run, if any, in the order they were started.
|
||||||
|
If filters were supplied, this list will contain only the runs that matched the
|
||||||
|
filters."""
|
||||||
|
|
||||||
|
|
||||||
|
class RunLogPatch:
|
||||||
|
ops: List[Dict[str, Any]]
|
||||||
|
"""List of jsonpatch operations, which describe how to create the run state
|
||||||
|
from an empty dict. This is the minimal representation of the log, designed to
|
||||||
|
be serialized as JSON and sent over the wire to reconstruct the log on the other
|
||||||
|
side. Reconstruction of the state can be done with any jsonpatch-compliant library,
|
||||||
|
see https://jsonpatch.com for more information."""
|
||||||
|
|
||||||
|
def __init__(self, *ops: Dict[str, Any]) -> None:
|
||||||
|
self.ops = list(ops)
|
||||||
|
|
||||||
|
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
|
||||||
|
if type(other) == RunLogPatch:
|
||||||
|
ops = self.ops + other.ops
|
||||||
|
state = jsonpatch.apply_patch(None, ops)
|
||||||
|
return RunLog(*ops, state=state)
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
return f"RunLogPatch(ops={pformat(self.ops)})"
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return isinstance(other, RunLogPatch) and self.ops == other.ops
|
||||||
|
|
||||||
|
|
||||||
|
class RunLog(RunLogPatch):
|
||||||
|
state: RunState
|
||||||
|
"""Current state of the log, obtained from applying all ops in sequence."""
|
||||||
|
|
||||||
|
def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
|
||||||
|
super().__init__(*ops)
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
|
||||||
|
if type(other) == RunLogPatch:
|
||||||
|
ops = self.ops + other.ops
|
||||||
|
state = jsonpatch.apply_patch(self.state, other.ops)
|
||||||
|
return RunLog(*ops, state=state)
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
return f"RunLog(state={pformat(self.state)})"
|
||||||
|
|
||||||
|
|
||||||
|
class LogStreamCallbackHandler(BaseTracer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
auto_close: bool = True,
|
||||||
|
include_names: Optional[Sequence[str]] = None,
|
||||||
|
include_types: Optional[Sequence[str]] = None,
|
||||||
|
include_tags: Optional[Sequence[str]] = None,
|
||||||
|
exclude_names: Optional[Sequence[str]] = None,
|
||||||
|
exclude_types: Optional[Sequence[str]] = None,
|
||||||
|
exclude_tags: Optional[Sequence[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.auto_close = auto_close
|
||||||
|
self.include_names = include_names
|
||||||
|
self.include_types = include_types
|
||||||
|
self.include_tags = include_tags
|
||||||
|
self.exclude_names = exclude_names
|
||||||
|
self.exclude_types = exclude_types
|
||||||
|
self.exclude_tags = exclude_tags
|
||||||
|
|
||||||
|
send_stream, receive_stream = create_memory_object_stream(
|
||||||
|
math.inf, item_type=RunLogPatch
|
||||||
|
)
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.send_stream = send_stream
|
||||||
|
self.receive_stream = receive_stream
|
||||||
|
self._index_map: Dict[UUID, int] = {}
|
||||||
|
|
||||||
|
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
|
||||||
|
return self.receive_stream.__aiter__()
|
||||||
|
|
||||||
|
def include_run(self, run: Run) -> bool:
|
||||||
|
if run.parent_run_id is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
run_tags = run.tags or []
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.include_names is None
|
||||||
|
and self.include_types is None
|
||||||
|
and self.include_tags is None
|
||||||
|
):
|
||||||
|
include = True
|
||||||
|
else:
|
||||||
|
include = False
|
||||||
|
|
||||||
|
if self.include_names is not None:
|
||||||
|
include = include or run.name in self.include_names
|
||||||
|
if self.include_types is not None:
|
||||||
|
include = include or run.run_type in self.include_types
|
||||||
|
if self.include_tags is not None:
|
||||||
|
include = include or any(tag in self.include_tags for tag in run_tags)
|
||||||
|
|
||||||
|
if self.exclude_names is not None:
|
||||||
|
include = include and run.name not in self.exclude_names
|
||||||
|
if self.exclude_types is not None:
|
||||||
|
include = include and run.run_type not in self.exclude_types
|
||||||
|
if self.exclude_tags is not None:
|
||||||
|
include = include and all(tag not in self.exclude_tags for tag in run_tags)
|
||||||
|
|
||||||
|
return include
|
||||||
|
|
||||||
|
def _persist_run(self, run: Run) -> None:
|
||||||
|
# This is a legacy method only called once for an entire run tree
|
||||||
|
# therefore not useful here
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_run_create(self, run: Run) -> None:
|
||||||
|
"""Start a run."""
|
||||||
|
if run.parent_run_id is None:
|
||||||
|
self.send_stream.send_nowait(
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "",
|
||||||
|
"value": RunState(
|
||||||
|
id=run.id,
|
||||||
|
streamed_output=[],
|
||||||
|
final_output=None,
|
||||||
|
logs=[],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.include_run(run):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine previous index, increment by 1
|
||||||
|
with self.lock:
|
||||||
|
self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
|
||||||
|
|
||||||
|
# Add the run to the stream
|
||||||
|
self.send_stream.send_nowait(
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": f"/logs/{self._index_map[run.id]}",
|
||||||
|
"value": LogEntry(
|
||||||
|
id=str(run.id),
|
||||||
|
name=run.name,
|
||||||
|
type=run.run_type,
|
||||||
|
tags=run.tags or [],
|
||||||
|
metadata=run.extra.get("metadata", {}),
|
||||||
|
start_time=run.start_time.isoformat(timespec="milliseconds"),
|
||||||
|
streamed_output_str=[],
|
||||||
|
final_output=None,
|
||||||
|
end_time=None,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_run_update(self, run: Run) -> None:
|
||||||
|
"""Finish a run."""
|
||||||
|
try:
|
||||||
|
index = self._index_map.get(run.id)
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.send_stream.send_nowait(
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": f"/logs/{index}/final_output",
|
||||||
|
"value": run.outputs,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": f"/logs/{index}/end_time",
|
||||||
|
"value": run.end_time.isoformat(timespec="milliseconds"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if run.parent_run_id is None:
|
||||||
|
self.send_stream.send_nowait(
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/final_output",
|
||||||
|
"value": run.outputs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.auto_close:
|
||||||
|
self.send_stream.close()
|
||||||
|
|
||||||
|
def _on_llm_new_token(
|
||||||
|
self,
|
||||||
|
run: Run,
|
||||||
|
token: str,
|
||||||
|
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||||
|
) -> None:
|
||||||
|
"""Process new LLM token."""
|
||||||
|
index = self._index_map.get(run.id)
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.send_stream.send_nowait(
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": f"/logs/{index}/streamed_output_str/-",
|
||||||
|
"value": token,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
@ -34,6 +34,8 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field
|
||||||
@ -190,6 +192,89 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
yield await self.ainvoke(input, config, **kwargs)
|
yield await self.ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
async def astream_log(
|
||||||
|
self,
|
||||||
|
input: Any,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
*,
|
||||||
|
include_names: Optional[Sequence[str]] = None,
|
||||||
|
include_types: Optional[Sequence[str]] = None,
|
||||||
|
include_tags: Optional[Sequence[str]] = None,
|
||||||
|
exclude_names: Optional[Sequence[str]] = None,
|
||||||
|
exclude_types: Optional[Sequence[str]] = None,
|
||||||
|
exclude_tags: Optional[Sequence[str]] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[RunLogPatch]:
|
||||||
|
"""
|
||||||
|
Stream all output from a runnable, as reported to the callback system.
|
||||||
|
This includes all inner runs of LLMs, Retrievers, Tools, etc.
|
||||||
|
|
||||||
|
Output is streamed as Log objects, which include a list of
|
||||||
|
jsonpatch ops that describe how the state of the run has changed in each
|
||||||
|
step, and the final state of the run.
|
||||||
|
|
||||||
|
The jsonpatch ops can be applied in order to construct state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a stream handler that will emit Log objects
|
||||||
|
stream = LogStreamCallbackHandler(
|
||||||
|
auto_close=False,
|
||||||
|
include_names=include_names,
|
||||||
|
include_types=include_types,
|
||||||
|
include_tags=include_tags,
|
||||||
|
exclude_names=exclude_names,
|
||||||
|
exclude_types=exclude_types,
|
||||||
|
exclude_tags=exclude_tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign the stream handler to the config
|
||||||
|
config = config or {}
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks is None:
|
||||||
|
config["callbacks"] = [stream]
|
||||||
|
elif isinstance(callbacks, list):
|
||||||
|
config["callbacks"] = callbacks + [stream]
|
||||||
|
elif isinstance(callbacks, BaseCallbackManager):
|
||||||
|
callbacks = callbacks.copy()
|
||||||
|
callbacks.inheritable_handlers.append(stream)
|
||||||
|
config["callbacks"] = callbacks
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected type for callbacks: {callbacks}."
|
||||||
|
"Expected None, list or AsyncCallbackManager."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the runnable in streaming mode,
|
||||||
|
# add each chunk to the output stream
|
||||||
|
async def consume_astream() -> None:
|
||||||
|
try:
|
||||||
|
async for chunk in self.astream(input, config, **kwargs):
|
||||||
|
await stream.send_stream.send(
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/streamed_output/-",
|
||||||
|
"value": chunk,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await stream.send_stream.aclose()
|
||||||
|
|
||||||
|
# Start the runnable in a task, so we can start consuming output
|
||||||
|
task = asyncio.create_task(consume_astream())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Yield each chunk from the output stream
|
||||||
|
async for log in stream:
|
||||||
|
yield log
|
||||||
|
finally:
|
||||||
|
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
input: Iterator[Input],
|
||||||
|
16
libs/langchain/poetry.lock
generated
16
libs/langchain/poetry.lock
generated
@ -3610,6 +3610,20 @@ files = [
|
|||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
attrs = ">=19.2.0"
|
attrs = ">=19.2.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jsonpatch"
|
||||||
|
version = "1.33"
|
||||||
|
description = "Apply JSON-Patches (RFC 6902)"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
|
||||||
|
files = [
|
||||||
|
{file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
|
||||||
|
{file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
jsonpointer = ">=1.9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jsonpointer"
|
name = "jsonpointer"
|
||||||
version = "2.4"
|
version = "2.4"
|
||||||
@ -10608,4 +10622,4 @@ text-helpers = ["chardet"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "11ce1c967a78f79a922b9bbbc1c00541703185e28c63b7a0a02aa5c562c36ee3"
|
content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85"
|
||||||
|
@ -129,6 +129,8 @@ markdownify = {version = "^0.11.6", optional = true}
|
|||||||
assemblyai = {version = "^0.17.0", optional = true}
|
assemblyai = {version = "^0.17.0", optional = true}
|
||||||
dashvector = {version = "^1.0.1", optional = true}
|
dashvector = {version = "^1.0.1", optional = true}
|
||||||
sqlite-vss = {version = "^0.1.2", optional = true}
|
sqlite-vss = {version = "^0.1.2", optional = true}
|
||||||
|
anyio = "<4.0"
|
||||||
|
jsonpatch = "^1.33"
|
||||||
timescale-vector = {version = "^0.0.1", optional = true}
|
timescale-vector = {version = "^0.0.1", optional = true}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -9,6 +9,7 @@ from syrupy import SnapshotAssertion
|
|||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks, collect_runs
|
from langchain.callbacks.manager import Callbacks, collect_runs
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
from langchain.chat_models.fake import FakeListChatModel
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
@ -368,6 +369,62 @@ async def test_prompt() -> None:
|
|||||||
part async for part in prompt.astream({"question": "What is your name?"})
|
part async for part in prompt.astream({"question": "What is your name?"})
|
||||||
] == [expected]
|
] == [expected]
|
||||||
|
|
||||||
|
stream_log = [
|
||||||
|
part async for part in prompt.astream_log({"question": "What is your name?"})
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(stream_log[0].ops) == 1
|
||||||
|
assert stream_log[0].ops[0]["op"] == "replace"
|
||||||
|
assert stream_log[0].ops[0]["path"] == ""
|
||||||
|
assert stream_log[0].ops[0]["value"]["logs"] == []
|
||||||
|
assert stream_log[0].ops[0]["value"]["final_output"] is None
|
||||||
|
assert stream_log[0].ops[0]["value"]["streamed_output"] == []
|
||||||
|
assert type(stream_log[0].ops[0]["value"]["id"]) == UUID
|
||||||
|
|
||||||
|
assert stream_log[1:] == [
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/final_output",
|
||||||
|
"value": {
|
||||||
|
"id": ["langchain", "prompts", "chat", "ChatPromptValue"],
|
||||||
|
"kwargs": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"messages",
|
||||||
|
"SystemMessage",
|
||||||
|
],
|
||||||
|
"kwargs": {"content": "You are a nice " "assistant."},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"messages",
|
||||||
|
"HumanMessage",
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"additional_kwargs": {},
|
||||||
|
"content": "What is your " "name?",
|
||||||
|
},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_params() -> None:
|
def test_prompt_template_params() -> None:
|
||||||
prompt = ChatPromptTemplate.from_template(
|
prompt = ChatPromptTemplate.from_template(
|
||||||
@ -560,7 +617,7 @@ async def test_prompt_with_llm(
|
|||||||
mocker.stop(prompt_spy)
|
mocker.stop(prompt_spy)
|
||||||
mocker.stop(llm_spy)
|
mocker.stop(llm_spy)
|
||||||
|
|
||||||
# Test stream#
|
# Test stream
|
||||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
llm_spy = mocker.spy(llm.__class__, "astream")
|
llm_spy = mocker.spy(llm.__class__, "astream")
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
@ -578,6 +635,136 @@ async def test_prompt_with_llm(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt_spy.reset_mock()
|
||||||
|
llm_spy.reset_mock()
|
||||||
|
stream_log = [
|
||||||
|
part async for part in chain.astream_log({"question": "What is your name?"})
|
||||||
|
]
|
||||||
|
|
||||||
|
# remove ids from logs
|
||||||
|
for part in stream_log:
|
||||||
|
for op in part.ops:
|
||||||
|
if (
|
||||||
|
isinstance(op["value"], dict)
|
||||||
|
and "id" in op["value"]
|
||||||
|
and not isinstance(op["value"]["id"], list) # serialized lc id
|
||||||
|
):
|
||||||
|
del op["value"]["id"]
|
||||||
|
|
||||||
|
assert stream_log == [
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "",
|
||||||
|
"value": {
|
||||||
|
"logs": [],
|
||||||
|
"final_output": None,
|
||||||
|
"streamed_output": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/logs/0",
|
||||||
|
"value": {
|
||||||
|
"end_time": None,
|
||||||
|
"final_output": None,
|
||||||
|
"metadata": {},
|
||||||
|
"name": "ChatPromptTemplate",
|
||||||
|
"start_time": "2023-01-01T00:00:00.000",
|
||||||
|
"streamed_output_str": [],
|
||||||
|
"tags": ["seq:step:1"],
|
||||||
|
"type": "prompt",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/logs/0/final_output",
|
||||||
|
"value": {
|
||||||
|
"id": ["langchain", "prompts", "chat", "ChatPromptValue"],
|
||||||
|
"kwargs": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"messages",
|
||||||
|
"SystemMessage",
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"additional_kwargs": {},
|
||||||
|
"content": "You are a nice " "assistant.",
|
||||||
|
},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"messages",
|
||||||
|
"HumanMessage",
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"additional_kwargs": {},
|
||||||
|
"content": "What is your " "name?",
|
||||||
|
},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/logs/0/end_time",
|
||||||
|
"value": "2023-01-01T00:00:00.000",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/logs/1",
|
||||||
|
"value": {
|
||||||
|
"end_time": None,
|
||||||
|
"final_output": None,
|
||||||
|
"metadata": {},
|
||||||
|
"name": "FakeListLLM",
|
||||||
|
"start_time": "2023-01-01T00:00:00.000",
|
||||||
|
"streamed_output_str": [],
|
||||||
|
"tags": ["seq:step:2"],
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/logs/1/final_output",
|
||||||
|
"value": {
|
||||||
|
"generations": [[{"generation_info": None, "text": "foo"}]],
|
||||||
|
"llm_output": None,
|
||||||
|
"run": None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": "/logs/1/end_time",
|
||||||
|
"value": "2023-01-01T00:00:00.000",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"}),
|
||||||
|
RunLogPatch(
|
||||||
|
{"op": "replace", "path": "/final_output", "value": {"output": "foo"}}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
@ -1213,6 +1400,74 @@ async def test_map_astream() -> None:
|
|||||||
{"question": "What is your name?"}
|
{"question": "What is your name?"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test astream_log state accumulation
|
||||||
|
|
||||||
|
final_state = None
|
||||||
|
streamed_ops = []
|
||||||
|
async for chunk in chain.astream_log({"question": "What is your name?"}):
|
||||||
|
streamed_ops.extend(chunk.ops)
|
||||||
|
if final_state is None:
|
||||||
|
final_state = chunk
|
||||||
|
else:
|
||||||
|
final_state += chunk
|
||||||
|
final_state = cast(RunLog, final_state)
|
||||||
|
|
||||||
|
assert final_state.state["final_output"] == final_value
|
||||||
|
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
|
||||||
|
assert isinstance(final_state.state["id"], UUID)
|
||||||
|
assert len(final_state.ops) == len(streamed_ops)
|
||||||
|
assert len(final_state.state["logs"]) == 5
|
||||||
|
assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
|
||||||
|
assert final_state.state["logs"][0]["final_output"] == dumpd(
|
||||||
|
prompt.invoke({"question": "What is your name?"})
|
||||||
|
)
|
||||||
|
assert final_state.state["logs"][1]["name"] == "RunnableMap"
|
||||||
|
assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
|
||||||
|
"FakeListChatModel",
|
||||||
|
"FakeStreamingListLLM",
|
||||||
|
"RunnablePassthrough",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test astream_log with include filters
|
||||||
|
final_state = None
|
||||||
|
async for chunk in chain.astream_log(
|
||||||
|
{"question": "What is your name?"}, include_names=["FakeListChatModel"]
|
||||||
|
):
|
||||||
|
if final_state is None:
|
||||||
|
final_state = chunk
|
||||||
|
else:
|
||||||
|
final_state += chunk
|
||||||
|
final_state = cast(RunLog, final_state)
|
||||||
|
|
||||||
|
assert final_state.state["final_output"] == final_value
|
||||||
|
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
|
||||||
|
assert len(final_state.state["logs"]) == 1
|
||||||
|
assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
|
||||||
|
|
||||||
|
# Test astream_log with exclude filters
|
||||||
|
final_state = None
|
||||||
|
async for chunk in chain.astream_log(
|
||||||
|
{"question": "What is your name?"}, exclude_names=["FakeListChatModel"]
|
||||||
|
):
|
||||||
|
if final_state is None:
|
||||||
|
final_state = chunk
|
||||||
|
else:
|
||||||
|
final_state += chunk
|
||||||
|
final_state = cast(RunLog, final_state)
|
||||||
|
|
||||||
|
assert final_state.state["final_output"] == final_value
|
||||||
|
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
|
||||||
|
assert len(final_state.state["logs"]) == 4
|
||||||
|
assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
|
||||||
|
assert final_state.state["logs"][0]["final_output"] == dumpd(
|
||||||
|
prompt.invoke({"question": "What is your name?"})
|
||||||
|
)
|
||||||
|
assert final_state.state["logs"][1]["name"] == "RunnableMap"
|
||||||
|
assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
|
||||||
|
"FakeStreamingListLLM",
|
||||||
|
"RunnablePassthrough",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_map_astream_iterator_input() -> None:
|
async def test_map_astream_iterator_input() -> None:
|
||||||
|
@ -39,8 +39,10 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
|
|||||||
"PyYAML",
|
"PyYAML",
|
||||||
"SQLAlchemy",
|
"SQLAlchemy",
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
"anyio",
|
||||||
"async-timeout",
|
"async-timeout",
|
||||||
"dataclasses-json",
|
"dataclasses-json",
|
||||||
|
"jsonpatch",
|
||||||
"langsmith",
|
"langsmith",
|
||||||
"numexpr",
|
"numexpr",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
Loading…
Reference in New Issue
Block a user