Add Tags for LLMs (#6229)

- [x] Add tracing tags to LLMs + Chat Models (both inheritable and
local)
- [x] Add tags for the run_on_dataset helper function(s)
This commit is contained in:
Zander Chase 2023-06-15 11:24:11 -07:00 committed by GitHub
parent 8e1a7a8646
commit ae76e473e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 27 deletions

View File

@ -39,6 +39,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
"""Whether to print out response text.""" """Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True) callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
@root_validator() @root_validator()
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:
@ -65,6 +67,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Top Level call""" """Top Level call"""
@ -74,7 +78,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
options = {"stop": stop} options = {"stop": stop}
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
) )
run_manager = callback_manager.on_chat_model_start( run_manager = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options dumpd(self), messages, invocation_params=params, options=options
@ -106,6 +114,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Top Level call""" """Top Level call"""
@ -114,7 +124,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
options = {"stop": stop} options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
) )
run_manager = await callback_manager.on_chat_model_start( run_manager = await callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options dumpd(self), messages, invocation_params=params, options=options

View File

@ -5,7 +5,16 @@ import asyncio
import functools import functools
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union from typing import (
Any,
Callable,
Coroutine,
Dict,
Iterator,
List,
Optional,
Union,
)
from langchainplus_sdk import LangChainPlusClient from langchainplus_sdk import LangChainPlusClient
from langchainplus_sdk.schemas import Example from langchainplus_sdk.schemas import Example
@ -104,6 +113,8 @@ async def _arun_llm(
llm: BaseLanguageModel, llm: BaseLanguageModel,
inputs: Dict[str, Any], inputs: Dict[str, Any],
langchain_tracer: Optional[LangChainTracer], langchain_tracer: Optional[LangChainTracer],
*,
tags: Optional[List[str]] = None,
) -> Union[LLMResult, ChatResult]: ) -> Union[LLMResult, ChatResult]:
callbacks: Optional[List[BaseCallbackHandler]] = ( callbacks: Optional[List[BaseCallbackHandler]] = (
[langchain_tracer] if langchain_tracer else None [langchain_tracer] if langchain_tracer else None
@ -111,21 +122,27 @@ async def _arun_llm(
if isinstance(llm, BaseLLM): if isinstance(llm, BaseLLM):
try: try:
llm_prompts = _get_prompts(inputs) llm_prompts = _get_prompts(inputs)
llm_output = await llm.agenerate(llm_prompts, callbacks=callbacks) llm_output = await llm.agenerate(
llm_prompts, callbacks=callbacks, tags=tags
)
except InputFormatError: except InputFormatError:
llm_messages = _get_messages(inputs) llm_messages = _get_messages(inputs)
buffer_strings = [get_buffer_string(messages) for messages in llm_messages] buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
llm_output = await llm.agenerate(buffer_strings, callbacks=callbacks) llm_output = await llm.agenerate(
buffer_strings, callbacks=callbacks, tags=tags
)
elif isinstance(llm, BaseChatModel): elif isinstance(llm, BaseChatModel):
try: try:
messages = _get_messages(inputs) messages = _get_messages(inputs)
llm_output = await llm.agenerate(messages, callbacks=callbacks) llm_output = await llm.agenerate(messages, callbacks=callbacks, tags=tags)
except InputFormatError: except InputFormatError:
prompts = _get_prompts(inputs) prompts = _get_prompts(inputs)
converted_messages: List[List[BaseMessage]] = [ converted_messages: List[List[BaseMessage]] = [
[HumanMessage(content=prompt)] for prompt in prompts [HumanMessage(content=prompt)] for prompt in prompts
] ]
llm_output = await llm.agenerate(converted_messages, callbacks=callbacks) llm_output = await llm.agenerate(
converted_messages, callbacks=callbacks, tags=tags
)
else: else:
raise ValueError(f"Unsupported LLM type {type(llm)}") raise ValueError(f"Unsupported LLM type {type(llm)}")
return llm_output return llm_output
@ -136,6 +153,8 @@ async def _arun_llm_or_chain(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int, n_repetitions: int,
langchain_tracer: Optional[LangChainTracer], langchain_tracer: Optional[LangChainTracer],
*,
tags: Optional[List[str]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain asynchronously.""" """Run the chain asynchronously."""
if langchain_tracer is not None: if langchain_tracer is not None:
@ -150,11 +169,16 @@ async def _arun_llm_or_chain(
try: try:
if isinstance(llm_or_chain_factory, BaseLanguageModel): if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = await _arun_llm( output: Any = await _arun_llm(
llm_or_chain_factory, example.inputs, langchain_tracer llm_or_chain_factory,
example.inputs,
langchain_tracer,
tags=tags,
) )
else: else:
chain = llm_or_chain_factory() chain = llm_or_chain_factory()
output = await chain.acall(example.inputs, callbacks=callbacks) output = await chain.acall(
example.inputs, callbacks=callbacks, tags=tags
)
outputs.append(output) outputs.append(output)
except Exception as e: except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}") logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@ -230,6 +254,7 @@ async def arun_on_examples(
num_repetitions: int = 1, num_repetitions: int = 1,
session_name: Optional[str] = None, session_name: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run the chain on examples and store traces to the specified session name. Run the chain on examples and store traces to the specified session name.
@ -245,6 +270,7 @@ async def arun_on_examples(
intervals. intervals.
session_name: Session name to use when tracing runs. session_name: Session name to use when tracing runs.
verbose: Whether to print progress. verbose: Whether to print progress.
tags: Tags to add to the traces.
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary mapping example ids to the model outputs.
@ -260,6 +286,7 @@ async def arun_on_examples(
llm_or_chain_factory, llm_or_chain_factory,
num_repetitions, num_repetitions,
tracer, tracer,
tags=tags,
) )
results[str(example.id)] = result results[str(example.id)] = result
job_state["num_processed"] += 1 job_state["num_processed"] += 1
@ -282,12 +309,14 @@ def run_llm(
llm: BaseLanguageModel, llm: BaseLanguageModel,
inputs: Dict[str, Any], inputs: Dict[str, Any],
callbacks: Callbacks, callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
) -> Union[LLMResult, ChatResult]: ) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example.""" """Run the language model on the example."""
if isinstance(llm, BaseLLM): if isinstance(llm, BaseLLM):
try: try:
llm_prompts = _get_prompts(inputs) llm_prompts = _get_prompts(inputs)
llm_output = llm.generate(llm_prompts, callbacks=callbacks) llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
except InputFormatError: except InputFormatError:
llm_messages = _get_messages(inputs) llm_messages = _get_messages(inputs)
buffer_strings = [get_buffer_string(messages) for messages in llm_messages] buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
@ -295,13 +324,15 @@ def run_llm(
elif isinstance(llm, BaseChatModel): elif isinstance(llm, BaseChatModel):
try: try:
messages = _get_messages(inputs) messages = _get_messages(inputs)
llm_output = llm.generate(messages, callbacks=callbacks) llm_output = llm.generate(messages, callbacks=callbacks, tags=tags)
except InputFormatError: except InputFormatError:
prompts = _get_prompts(inputs) prompts = _get_prompts(inputs)
converted_messages: List[List[BaseMessage]] = [ converted_messages: List[List[BaseMessage]] = [
[HumanMessage(content=prompt)] for prompt in prompts [HumanMessage(content=prompt)] for prompt in prompts
] ]
llm_output = llm.generate(converted_messages, callbacks=callbacks) llm_output = llm.generate(
converted_messages, callbacks=callbacks, tags=tags
)
else: else:
raise ValueError(f"Unsupported LLM type {type(llm)}") raise ValueError(f"Unsupported LLM type {type(llm)}")
return llm_output return llm_output
@ -312,6 +343,8 @@ def run_llm_or_chain(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int, n_repetitions: int,
langchain_tracer: Optional[LangChainTracer] = None, langchain_tracer: Optional[LangChainTracer] = None,
*,
tags: Optional[List[str]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain synchronously.""" """Run the chain synchronously."""
if langchain_tracer is not None: if langchain_tracer is not None:
@ -325,10 +358,12 @@ def run_llm_or_chain(
for _ in range(n_repetitions): for _ in range(n_repetitions):
try: try:
if isinstance(llm_or_chain_factory, BaseLanguageModel): if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = run_llm(llm_or_chain_factory, example.inputs, callbacks) output: Any = run_llm(
llm_or_chain_factory, example.inputs, callbacks, tags=tags
)
else: else:
chain = llm_or_chain_factory() chain = llm_or_chain_factory()
output = chain(example.inputs, callbacks=callbacks) output = chain(example.inputs, callbacks=callbacks, tags=tags)
outputs.append(output) outputs.append(output)
except Exception as e: except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}") logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@ -345,6 +380,7 @@ def run_on_examples(
num_repetitions: int = 1, num_repetitions: int = 1,
session_name: Optional[str] = None, session_name: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the chain on examples and store traces to the specified session name. """Run the chain on examples and store traces to the specified session name.
@ -359,6 +395,7 @@ def run_on_examples(
intervals. intervals.
session_name: Session name to use when tracing runs. session_name: Session name to use when tracing runs.
verbose: Whether to print progress. verbose: Whether to print progress.
tags: Tags to add to the run traces.
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary mapping example ids to the model outputs.
""" """
@ -370,6 +407,7 @@ def run_on_examples(
llm_or_chain_factory, llm_or_chain_factory,
num_repetitions, num_repetitions,
langchain_tracer=tracer, langchain_tracer=tracer,
tags=tags,
) )
if verbose: if verbose:
print(f"{i+1} processed", flush=True, end="\r") print(f"{i+1} processed", flush=True, end="\r")
@ -401,6 +439,7 @@ async def arun_on_dataset(
session_name: Optional[str] = None, session_name: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
client: Optional[LangChainPlusClient] = None, client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run the chain on a dataset and store traces to the specified session name. Run the chain on a dataset and store traces to the specified session name.
@ -420,6 +459,7 @@ async def arun_on_dataset(
verbose: Whether to print progress. verbose: Whether to print progress.
client: Client to use to read the dataset. If not provided, a new client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment. client will be created using the credentials in the environment.
tags: Tags to add to each run in the sesssion.
Returns: Returns:
A dictionary containing the run's session name and the resulting model outputs. A dictionary containing the run's session name and the resulting model outputs.
@ -436,6 +476,7 @@ async def arun_on_dataset(
num_repetitions=num_repetitions, num_repetitions=num_repetitions,
session_name=session_name, session_name=session_name,
verbose=verbose, verbose=verbose,
tags=tags,
) )
return { return {
"session_name": session_name, "session_name": session_name,
@ -451,6 +492,7 @@ def run_on_dataset(
session_name: Optional[str] = None, session_name: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
client: Optional[LangChainPlusClient] = None, client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified session name. """Run the chain on a dataset and store traces to the specified session name.
@ -468,6 +510,7 @@ def run_on_dataset(
verbose: Whether to print progress. verbose: Whether to print progress.
client: Client to use to access the dataset. If None, a new client client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment. will be created using the credentials in the environment.
tags: Tags to add to each run in the sesssion.
Returns: Returns:
A dictionary containing the run's session name and the resulting model outputs. A dictionary containing the run's session name and the resulting model outputs.
@ -482,6 +525,7 @@ def run_on_dataset(
num_repetitions=num_repetitions, num_repetitions=num_repetitions,
session_name=session_name, session_name=session_name,
verbose=verbose, verbose=verbose,
tags=tags,
) )
return { return {
"session_name": session_name, "session_name": session_name,

View File

@ -369,6 +369,7 @@
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mclient\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[LangChainPlusClient]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mclient\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[LangChainPlusClient]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[List[str]]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m\n", "\u001b[0;31mDocstring:\u001b[0m\n",
"Run the chain on a dataset and store traces to the specified session name.\n", "Run the chain on a dataset and store traces to the specified session name.\n",
@ -388,6 +389,7 @@
" verbose: Whether to print progress.\n", " verbose: Whether to print progress.\n",
" client: Client to use to read the dataset. If not provided, a new\n", " client: Client to use to read the dataset. If not provided, a new\n",
" client will be created using the credentials in the environment.\n", " client will be created using the credentials in the environment.\n",
" tags: Tags to add to each run in the sesssion.\n",
"\n", "\n",
"Returns:\n", "Returns:\n",
" A dictionary containing the run's session name and the resulting model outputs.\n", " A dictionary containing the run's session name and the resulting model outputs.\n",
@ -430,7 +432,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 13,
"id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33", "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -440,21 +442,21 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Processed examples: 4\r" "Processed examples: 1\r"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n" "Chain failed for example b36a82d3-4fb6-4bc4-87df-b7c355742b8e. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Processed examples: 5\r" "Processed examples: 6\r"
] ]
} }
], ],
@ -465,6 +467,7 @@
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n", " concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
" verbose=True,\n", " verbose=True,\n",
" client=client,\n", " client=client,\n",
" tags=[\"testing-notebook\", \"turbo\"], # Optional, adds a tag to the resulting chain runs\n",
")\n", ")\n",
"\n", "\n",
"# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n", "# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n",
@ -486,7 +489,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 14,
"id": "136db492-d6ca-4215-96f9-439c23538241", "id": "136db492-d6ca-4215-96f9-439c23538241",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -501,7 +504,7 @@
"LangChainPlusClient (API URL: https://dev.api.langchain.plus)" "LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
] ]
}, },
"execution_count": 13, "execution_count": 14,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -534,7 +537,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 15,
"id": "35db4025-9183-4e5f-ba14-0b1b380f49c7", "id": "35db4025-9183-4e5f-ba14-0b1b380f49c7",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -565,7 +568,7 @@
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "9989f6507cd04ea7a09ea3c5723dc984", "model_id": "5fce1ce42a8c4110b7d12443948ac697",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@ -592,12 +595,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 17,
"id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30", "id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [], "outputs": [
{
"data": {
"text/html": [
"<a href=\"https://dev.langchain.plus\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
],
"text/plain": [
"LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"client" "client"
] ]

View File

@ -79,6 +79,8 @@ class BaseLLM(BaseLanguageModel, ABC):
"""Whether to print out response text.""" """Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True) callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -155,6 +157,8 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
@ -176,7 +180,7 @@ class BaseLLM(BaseLanguageModel, ABC):
) = get_prompts(params, prompts) ) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache disregard_cache = self.cache is not None and not self.cache
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose, tags, self.tags
) )
new_arg_supported = inspect.signature(self._generate).parameters.get( new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager" "run_manager"
@ -241,6 +245,8 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
@ -255,7 +261,7 @@ class BaseLLM(BaseLanguageModel, ABC):
) = get_prompts(params, prompts) ) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache disregard_cache = self.cache is not None and not self.cache
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose, tags, self.tags
) )
new_arg_supported = inspect.signature(self._agenerate).parameters.get( new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager" "run_manager"

View File

@ -1,7 +1,7 @@
"""Test the LangChain+ client.""" """Test the LangChain+ client."""
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
from unittest import mock from unittest import mock
import pytest import pytest
@ -170,6 +170,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
llm_or_chain: Union[BaseLanguageModel, Chain], llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int, n_repetitions: int,
tracer: Any, tracer: Any,
tags: Optional[List[str]] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
return [ return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions) {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)