From 3d5e92e3ef0506f8f5b937767c5a58b584d6e58d Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 22 Sep 2023 16:41:25 +0100 Subject: [PATCH] Accept run name arg for non-chain runs (#10935) --- .../langchain/callbacks/tracers/base.py | 7 ++- .../langchain/callbacks/tracers/langchain.py | 2 + libs/langchain/langchain/chat_models/base.py | 28 ++++++++-- libs/langchain/langchain/llms/base.py | 54 ++++++++++++++++--- libs/langchain/langchain/schema/retriever.py | 6 +++ libs/langchain/langchain/tools/base.py | 6 +++ 6 files changed, 92 insertions(+), 11 deletions(-) diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py index 5b1427be43..f270a5ec55 100644 --- a/libs/langchain/langchain/callbacks/tracers/base.py +++ b/libs/langchain/langchain/callbacks/tracers/base.py @@ -102,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC): tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, **kwargs: Any, ) -> Run: """Start a trace for an LLM run.""" @@ -122,6 +123,7 @@ class BaseTracer(BaseCallbackHandler, ABC): child_execution_order=execution_order, run_type="llm", tags=tags or [], + name=name, ) self._start_trace(llm_run) self._on_llm_start(llm_run) @@ -335,6 +337,7 @@ class BaseTracer(BaseCallbackHandler, ABC): tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, **kwargs: Any, ) -> Run: """Start a trace for a tool run.""" @@ -356,6 +359,7 @@ class BaseTracer(BaseCallbackHandler, ABC): child_runs=[], run_type="tool", tags=tags or [], + name=name, ) self._start_trace(tool_run) self._on_tool_start(tool_run) @@ -406,6 +410,7 @@ class BaseTracer(BaseCallbackHandler, ABC): parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, **kwargs: Any, ) -> Run: """Run when Retriever starts running.""" @@ -416,7 +421,7 @@ class BaseTracer(BaseCallbackHandler, ABC): kwargs.update({"metadata": metadata}) retrieval_run = Run( id=run_id, - name="Retriever", + name=name or "Retriever", parent_run_id=parent_run_id, serialized=serialized, inputs={"query": query}, diff --git a/libs/langchain/langchain/callbacks/tracers/langchain.py b/libs/langchain/langchain/callbacks/tracers/langchain.py index 0e6393c78b..07cde9e568 100644 --- a/libs/langchain/langchain/callbacks/tracers/langchain.py +++ b/libs/langchain/langchain/callbacks/tracers/langchain.py @@ -98,6 +98,7 @@ class LangChainTracer(BaseTracer): tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" @@ -118,6 +119,7 @@ class LangChainTracer(BaseTracer): child_execution_order=execution_order, run_type="llm", tags=tags, + name=name, ) self._start_trace(chat_model_run) self._on_chat_model_start(chat_model_run) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 094418c94b..ebe1ef2933 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -139,6 +139,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ).generations[0][0], ).message, @@ -165,6 +166,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) return cast( @@ -197,7 +199,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): self.metadata, ) (run_manager,) = callback_manager.on_chat_model_start( - dumpd(self), [messages], invocation_params=params, options=options + dumpd(self), + [messages], + invocation_params=params, + options=options, + name=config.get("run_name"), ) try: generation: Optional[ChatGenerationChunk] = None @@ -244,7 +250,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): self.metadata, ) (run_manager,) = await callback_manager.on_chat_model_start( - dumpd(self), [messages], invocation_params=params, options=options + dumpd(self), + [messages], + invocation_params=params, + options=options, + name=config.get("run_name"), ) try: generation: Optional[ChatGenerationChunk] = None @@ -298,6 +308,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, **kwargs: Any, ) -> LLMResult: """Top Level call""" @@ -314,7 +325,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): self.metadata, ) run_managers = callback_manager.on_chat_model_start( - dumpd(self), messages, invocation_params=params, options=options + dumpd(self), + messages, + invocation_params=params, + options=options, + name=run_name, ) results = [] for i, m in enumerate(messages): @@ -354,6 +369,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, **kwargs: Any, ) -> LLMResult: """Top Level call""" @@ -371,7 +387,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): ) run_managers = await callback_manager.on_chat_model_start( - dumpd(self), messages, invocation_params=params, options=options + dumpd(self), + messages, + invocation_params=params, + options=options, + name=run_name, ) results = await asyncio.gather( diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 0c9c3158ec..bfaa85fdeb 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -228,6 +228,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) .generations[0][0] @@ -255,6 +256,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) return llm_result.generations[0][0].text @@ -280,6 +282,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks=[c.get("callbacks") for c in config], tags=[c.get("tags") for c in config], metadata=[c.get("metadata") for c in config], + run_name=[c.get("run_name") for c in config], **kwargs, ) return [g[0].text for g in llm_result.generations] @@ -328,6 +331,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks=[c.get("callbacks") for c in config], tags=[c.get("tags") for c in config], metadata=[c.get("metadata") for c in config], + run_name=[c.get("run_name") for c in config], **kwargs, ) return [g[0].text for g in llm_result.generations] @@ -375,7 +379,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): self.metadata, ) (run_manager,) = callback_manager.on_llm_start( - dumpd(self), [prompt], invocation_params=params, options=options + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=config.get("run_name"), ) try: generation: Optional[GenerationChunk] = None @@ -422,7 +430,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): self.metadata, ) (run_manager,) = await callback_manager.on_llm_start( - dumpd(self), [prompt], invocation_params=params, options=options + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=config.get("run_name"), ) try: generation: Optional[GenerationChunk] = None @@ -544,6 +556,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): *, tags: Optional[Union[List[str], List[List[str]]]] = None, metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + run_name: Optional[Union[str, List[str]]] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" @@ -569,11 +582,17 @@ class BaseLLM(BaseLanguageModel[str], ABC): assert metadata is None or ( isinstance(metadata, list) and len(metadata) == len(prompts) ) + assert run_name is None or ( + isinstance(run_name, list) and len(run_name) == len(prompts) + ) callbacks = cast(List[Callbacks], callbacks) tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) metadata_list = cast( List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) ) + run_name_list = run_name or cast( + List[Optional[str]], ([None] * len(prompts)) + ) callback_managers = [ CallbackManager.configure( callback, @@ -599,6 +618,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): self.metadata, ) ] * len(prompts) + run_name_list = [cast(Optional[str], run_name)] * len(prompts) params = self.dict() params["stop"] = stop @@ -620,9 +640,15 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) run_managers = [ callback_manager.on_llm_start( - dumpd(self), [prompt], invocation_params=params, options=options + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=run_name, )[0] - for callback_manager, prompt in zip(callback_managers, prompts) + for callback_manager, prompt, run_name in zip( + callback_managers, prompts, run_name_list + ) ] output = self._generate_helper( prompts, stop, run_managers, bool(new_arg_supported), **kwargs @@ -635,6 +661,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): [prompts[idx]], invocation_params=params, options=options, + name=run_name_list[idx], )[0] for idx in missing_prompt_idxs ] @@ -702,6 +729,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): *, tags: Optional[Union[List[str], List[List[str]]]] = None, metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + run_name: Optional[Union[str, List[str]]] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" @@ -718,11 +746,17 @@ class BaseLLM(BaseLanguageModel[str], ABC): assert metadata is None or ( isinstance(metadata, list) and len(metadata) == len(prompts) ) + assert run_name is None or ( + isinstance(run_name, list) and len(run_name) == len(prompts) + ) callbacks = cast(List[Callbacks], callbacks) tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) metadata_list = cast( List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) ) + run_name_list = run_name or cast( + List[Optional[str]], ([None] * len(prompts)) + ) callback_managers = [ AsyncCallbackManager.configure( callback, @@ -748,6 +782,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): self.metadata, ) ] * len(prompts) + run_name_list = [cast(Optional[str], run_name)] * len(prompts) params = self.dict() params["stop"] = stop @@ -770,9 +805,15 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_managers = await asyncio.gather( *[ callback_manager.on_llm_start( - dumpd(self), [prompt], invocation_params=params, options=options + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=run_name, + ) + for callback_manager, prompt, run_name in zip( + callback_managers, prompts, run_name_list ) - for callback_manager, prompt in zip(callback_managers, prompts) ] ) run_managers = [r[0] for r in run_managers] @@ -788,6 +829,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): [prompts[idx]], invocation_params=params, options=options, + name=run_name_list[idx], ) for idx in missing_prompt_idxs ] diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 5da50e1497..04c2835434 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -113,6 +113,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), ) async def ainvoke( @@ -131,6 +132,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), ) @abstractmethod @@ -164,6 +166,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """Retrieve documents relevant to a query. @@ -193,6 +196,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): run_manager = callback_manager.on_retriever_start( dumpd(self), query, + name=run_name, **kwargs, ) try: @@ -220,6 +224,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """Asynchronously get documents relevant to a query. @@ -249,6 +254,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): run_manager = await callback_manager.on_retriever_start( dumpd(self), query, + name=run_name, **kwargs, ) try: diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index b6fac8bf9d..69ce73f603 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -199,6 +199,7 @@ class ChildTool(BaseTool): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) @@ -218,6 +219,7 @@ class ChildTool(BaseTool): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) @@ -297,6 +299,7 @@ class ChildTool(BaseTool): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, **kwargs: Any, ) -> Any: """Run the tool.""" @@ -320,6 +323,7 @@ class ChildTool(BaseTool): {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, + name=run_name, **kwargs, ) try: @@ -370,6 +374,7 @@ class ChildTool(BaseTool): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -392,6 +397,7 @@ class ChildTool(BaseTool): {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, + name=run_name, **kwargs, ) try: