Accept run name arg for non-chain runs (#10935)

pull/10940/head
Nuno Campos 11 months ago committed by GitHub
parent aac2d4dcef
commit 3d5e92e3ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -102,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
@ -122,6 +123,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
run_type="llm",
tags=tags or [],
name=name,
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
@ -335,6 +337,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a tool run."""
@ -356,6 +359,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_runs=[],
run_type="tool",
tags=tags or [],
name=name,
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)
@ -406,6 +410,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Run when Retriever starts running."""
@ -416,7 +421,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
kwargs.update({"metadata": metadata})
retrieval_run = Run(
id=run_id,
name="Retriever",
name=name or "Retriever",
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"query": query},

@ -98,6 +98,7 @@ class LangChainTracer(BaseTracer):
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
@ -118,6 +119,7 @@ class LangChainTracer(BaseTracer):
child_execution_order=execution_order,
run_type="llm",
tags=tags,
name=name,
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)

@ -139,6 +139,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message,
@ -165,6 +166,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
return cast(
@ -197,7 +199,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
@ -244,7 +250,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
@ -298,6 +308,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@ -314,7 +325,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
self.metadata,
)
run_managers = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)
results = []
for i, m in enumerate(messages):
@ -354,6 +369,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@ -371,7 +387,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
)
run_managers = await callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)
results = await asyncio.gather(

@ -228,6 +228,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
.generations[0][0]
@ -255,6 +256,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
return llm_result.generations[0][0].text
@ -280,6 +282,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
run_name=[c.get("run_name") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
@ -328,6 +331,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
run_name=[c.get("run_name") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
@ -375,7 +379,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.metadata,
)
(run_manager,) = callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
dumpd(self),
[prompt],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[GenerationChunk] = None
@ -422,7 +430,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.metadata,
)
(run_manager,) = await callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
dumpd(self),
[prompt],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[GenerationChunk] = None
@ -544,6 +556,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@ -569,11 +582,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert metadata is None or (
isinstance(metadata, list) and len(metadata) == len(prompts)
)
assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
)
run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts))
)
callback_managers = [
CallbackManager.configure(
callback,
@ -599,6 +618,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.metadata,
)
] * len(prompts)
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
params = self.dict()
params["stop"] = stop
@ -620,9 +640,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
run_managers = [
callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
dumpd(self),
[prompt],
invocation_params=params,
options=options,
name=run_name,
)[0]
for callback_manager, prompt in zip(callback_managers, prompts)
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
)
]
output = self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
@ -635,6 +661,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
[prompts[idx]],
invocation_params=params,
options=options,
name=run_name_list[idx],
)[0]
for idx in missing_prompt_idxs
]
@ -702,6 +729,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@ -718,11 +746,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert metadata is None or (
isinstance(metadata, list) and len(metadata) == len(prompts)
)
assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
)
run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts))
)
callback_managers = [
AsyncCallbackManager.configure(
callback,
@ -748,6 +782,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.metadata,
)
] * len(prompts)
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
params = self.dict()
params["stop"] = stop
@ -770,9 +805,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_managers = await asyncio.gather(
*[
callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
dumpd(self),
[prompt],
invocation_params=params,
options=options,
name=run_name,
)
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
)
for callback_manager, prompt in zip(callback_managers, prompts)
]
)
run_managers = [r[0] for r in run_managers]
@ -788,6 +829,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
[prompts[idx]],
invocation_params=params,
options=options,
name=run_name_list[idx],
)
for idx in missing_prompt_idxs
]

@ -113,6 +113,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
)
async def ainvoke(
@ -131,6 +132,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
)
@abstractmethod
@ -164,6 +166,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Retrieve documents relevant to a query.
@ -193,6 +196,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
@ -220,6 +224,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
@ -249,6 +254,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:

@ -199,6 +199,7 @@ class ChildTool(BaseTool):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
@ -218,6 +219,7 @@ class ChildTool(BaseTool):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
@ -297,6 +299,7 @@ class ChildTool(BaseTool):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -320,6 +323,7 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
**kwargs,
)
try:
@ -370,6 +374,7 @@ class ChildTool(BaseTool):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -392,6 +397,7 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
**kwargs,
)
try:

Loading…
Cancel
Save