Accept run name arg for non-chain runs (#10935)

Author: Nuno Campos, 2023-09-22 16:41:25 +01:00 (committed by GitHub)
parent aac2d4dcef
commit 3d5e92e3ef
6 changed files with 92 additions and 11 deletions
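
As the diffs below show, non-chain runnables (LLMs, chat models, tools, retrievers) now read a run name from the standard `config` dict under the "run_name" key and forward it to the tracer. A minimal usage sketch, assuming an OpenAI API key is configured (model choice and prompt are illustrative):

    from langchain.chat_models import ChatOpenAI

    chat = ChatOpenAI()
    # The tracer records this chat-model run under "greeting-run"
    # instead of the default class name.
    chat.invoke("Hello!", config={"run_name": "greeting-run"})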

View File

@@ -102,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run."""
@@ -122,6 +123,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             child_execution_order=execution_order,
             run_type="llm",
             tags=tags or [],
+            name=name,
         )
         self._start_trace(llm_run)
         self._on_llm_start(llm_run)
@@ -335,6 +337,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for a tool run."""
@@ -356,6 +359,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             child_runs=[],
             run_type="tool",
             tags=tags or [],
+            name=name,
         )
         self._start_trace(tool_run)
         self._on_tool_start(tool_run)
@@ -406,6 +410,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> Run:
         """Run when Retriever starts running."""
@@ -416,7 +421,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             kwargs.update({"metadata": metadata})
         retrieval_run = Run(
             id=run_id,
-            name="Retriever",
+            name=name or "Retriever",
             parent_run_id=parent_run_id,
             serialized=serialized,
             inputs={"query": query},

View File

@@ -98,6 +98,7 @@ class LangChainTracer(BaseTracer):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Start a trace for an LLM run."""
@@ -118,6 +119,7 @@ class LangChainTracer(BaseTracer):
             child_execution_order=execution_order,
             run_type="llm",
             tags=tags,
+            name=name,
         )
         self._start_trace(chat_model_run)
         self._on_chat_model_start(chat_model_run)

View File

@@ -139,6 +139,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                     callbacks=config.get("callbacks"),
                     tags=config.get("tags"),
                     metadata=config.get("metadata"),
+                    run_name=config.get("run_name"),
                     **kwargs,
                 ).generations[0][0],
             ).message,
@@ -165,6 +166,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
             **kwargs,
         )
         return cast(
@@ -197,7 +199,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_chat_model_start(
-            dumpd(self), [messages], invocation_params=params, options=options
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[ChatGenerationChunk] = None
@@ -244,7 +250,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
-            dumpd(self), [messages], invocation_params=params, options=options
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[ChatGenerationChunk] = None
@@ -298,6 +308,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -314,7 +325,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             self.metadata,
         )
         run_managers = callback_manager.on_chat_model_start(
-            dumpd(self), messages, invocation_params=params, options=options
+            dumpd(self),
+            messages,
+            invocation_params=params,
+            options=options,
+            name=run_name,
         )
         results = []
         for i, m in enumerate(messages):
@@ -354,6 +369,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -371,7 +387,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         )
         run_managers = await callback_manager.on_chat_model_start(
-            dumpd(self), messages, invocation_params=params, options=options
+            dumpd(self),
+            messages,
+            invocation_params=params,
+            options=options,
+            name=run_name,
         )
         results = await asyncio.gather(

View File

@@ -228,6 +228,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=config.get("callbacks"),
                 tags=config.get("tags"),
                 metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
                 **kwargs,
             )
             .generations[0][0]
@@ -255,6 +256,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
             **kwargs,
         )
         return llm_result.generations[0][0].text
@@ -280,6 +282,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                run_name=[c.get("run_name") for c in config],
                 **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
@@ -328,6 +331,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                run_name=[c.get("run_name") for c in config],
                 **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
@@ -375,7 +379,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_llm_start(
-            dumpd(self), [prompt], invocation_params=params, options=options
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[GenerationChunk] = None
@@ -422,7 +430,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_llm_start(
-            dumpd(self), [prompt], invocation_params=params, options=options
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[GenerationChunk] = None
@@ -544,6 +556,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         *,
         tags: Optional[Union[List[str], List[List[str]]]] = None,
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        run_name: Optional[Union[str, List[str]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -569,11 +582,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             assert metadata is None or (
                 isinstance(metadata, list) and len(metadata) == len(prompts)
             )
+            assert run_name is None or (
+                isinstance(run_name, list) and len(run_name) == len(prompts)
+            )
             callbacks = cast(List[Callbacks], callbacks)
             tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
             metadata_list = cast(
                 List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
             )
+            run_name_list = run_name or cast(
+                List[Optional[str]], ([None] * len(prompts))
+            )
             callback_managers = [
                 CallbackManager.configure(
                     callback,
@@ -599,6 +618,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     self.metadata,
                 )
             ] * len(prompts)
+            run_name_list = [cast(Optional[str], run_name)] * len(prompts)

         params = self.dict()
         params["stop"] = stop
@@ -620,9 +640,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
             run_managers = [
                 callback_manager.on_llm_start(
-                    dumpd(self), [prompt], invocation_params=params, options=options
+                    dumpd(self),
+                    [prompt],
+                    invocation_params=params,
+                    options=options,
+                    name=run_name,
                 )[0]
-                for callback_manager, prompt in zip(callback_managers, prompts)
+                for callback_manager, prompt, run_name in zip(
+                    callback_managers, prompts, run_name_list
+                )
             ]
             output = self._generate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
@@ -635,6 +661,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     [prompts[idx]],
                     invocation_params=params,
                     options=options,
+                    name=run_name_list[idx],
                 )[0]
                 for idx in missing_prompt_idxs
             ]
@@ -702,6 +729,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         *,
         tags: Optional[Union[List[str], List[List[str]]]] = None,
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        run_name: Optional[Union[str, List[str]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -718,11 +746,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             assert metadata is None or (
                 isinstance(metadata, list) and len(metadata) == len(prompts)
             )
+            assert run_name is None or (
+                isinstance(run_name, list) and len(run_name) == len(prompts)
+            )
             callbacks = cast(List[Callbacks], callbacks)
             tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
             metadata_list = cast(
                 List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
             )
+            run_name_list = run_name or cast(
+                List[Optional[str]], ([None] * len(prompts))
+            )
             callback_managers = [
                 AsyncCallbackManager.configure(
                     callback,
@@ -748,6 +782,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     self.metadata,
                 )
             ] * len(prompts)
+            run_name_list = [cast(Optional[str], run_name)] * len(prompts)

         params = self.dict()
         params["stop"] = stop
@@ -770,9 +805,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         run_managers = await asyncio.gather(
             *[
                 callback_manager.on_llm_start(
-                    dumpd(self), [prompt], invocation_params=params, options=options
+                    dumpd(self),
+                    [prompt],
+                    invocation_params=params,
+                    options=options,
+                    name=run_name,
+                )
+                for callback_manager, prompt, run_name in zip(
+                    callback_managers, prompts, run_name_list
                 )
-                for callback_manager, prompt in zip(callback_managers, prompts)
             ]
         )
         run_managers = [r[0] for r in run_managers]
@@ -788,6 +829,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     [prompts[idx]],
                     invocation_params=params,
                     options=options,
+                    name=run_name_list[idx],
                 )
                 for idx in missing_prompt_idxs
             ]
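
As the `generate`/`agenerate` changes above show, `run_name` accepts either a single string, which is broadcast to every prompt via `run_name_list`, or, when per-prompt callbacks lists are supplied, a list aligned one-to-one with `prompts` (mirroring the `tags`/`metadata` asserts). A minimal sketch of the broadcast case, assuming an OpenAI API key is configured (model and prompts are illustrative):

    from langchain.llms import OpenAI

    llm = OpenAI()
    # A single name is applied to every prompt's run in the trace.
    result = llm.generate(
        ["What is 2 + 2?", "Name a colour."],
        run_name="demo-batch",
    )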

View File

@@ -113,6 +113,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
         )

     async def ainvoke(
@@ -131,6 +132,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
         )

     @abstractmethod
@@ -164,6 +166,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         callbacks: Callbacks = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Retrieve documents relevant to a query.
@@ -193,6 +196,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         run_manager = callback_manager.on_retriever_start(
             dumpd(self),
             query,
+            name=run_name,
             **kwargs,
         )
         try:
@@ -220,6 +224,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         callbacks: Callbacks = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Asynchronously get documents relevant to a query.
@@ -249,6 +254,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         run_manager = await callback_manager.on_retriever_start(
             dumpd(self),
             query,
+            name=run_name,
             **kwargs,
         )
         try:
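
For retrievers, `run_name` flows through `get_relevant_documents` to `on_retriever_start`, overriding the hard-coded "Retriever" default seen in the BaseTracer diff above. A minimal sketch with a toy retriever (the `EchoRetriever` subclass is purely illustrative):

    from typing import List

    from langchain.callbacks.manager import CallbackManagerForRetrieverRun
    from langchain.schema import BaseRetriever, Document

    class EchoRetriever(BaseRetriever):
        """Toy retriever that returns the query itself as a document."""

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            return [Document(page_content=query)]

    retriever = EchoRetriever()
    # The run is traced as "docs-lookup" rather than the default "Retriever".
    docs = retriever.get_relevant_documents("my query", run_name="docs-lookup")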

View File

@@ -199,6 +199,7 @@ class ChildTool(BaseTool):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
             **kwargs,
         )
@@ -218,6 +219,7 @@ class ChildTool(BaseTool):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
             **kwargs,
         )
@@ -297,6 +299,7 @@ class ChildTool(BaseTool):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> Any:
         """Run the tool."""
@@ -320,6 +323,7 @@ class ChildTool(BaseTool):
             {"name": self.name, "description": self.description},
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
+            name=run_name,
             **kwargs,
         )
         try:
@@ -370,6 +374,7 @@ class ChildTool(BaseTool):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> Any:
         """Run the tool asynchronously."""
@@ -392,6 +397,7 @@ class ChildTool(BaseTool):
             {"name": self.name, "description": self.description},
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
+            name=run_name,
             **kwargs,
         )
         try:
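
Similarly for tools, `run_name` is forwarded from `run`/`arun` (and from `invoke` via `config`) to `on_tool_start`, so a traced tool call can carry a custom name. A minimal sketch (the `reverse` tool is purely illustrative):

    from langchain.tools import tool

    @tool
    def reverse(text: str) -> str:
        """Reverse the input string."""
        return text[::-1]

    # Without run_name the trace would use the tool's own name ("reverse").
    result = reverse.run("hello", run_name="reverse-demo")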