Add Tags for LLMs (#6229)

- [x] Add tracing tags to LLMs + Chat Models (both inheritable and local)
- [x] Add tags for the run_on_dataset helper function(s)
Zander Chase 11 months ago committed by GitHub
parent 8e1a7a8646
commit ae76e473e1
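To make the two tagging modes concrete, here is a minimal usage sketch (assuming `ChatOpenAI`; any `BaseChatModel` or `BaseLLM` subclass gains the same field and keyword argument):

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

# Inheritable tags: set on the model instance, applied to every run it traces.
chat = ChatOpenAI(tags=["my-project"])

# Local tags: keyword-only, attached to this one call's run trace.
result = chat.generate(
    [[HumanMessage(content="Hello")]],
    tags=["one-off-experiment"],
)
```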

@@ -39,6 +39,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
@@ -65,6 +67,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@@ -74,7 +78,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
options = {"stop": stop}
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
)
run_manager = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
@@ -106,6 +114,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@@ -114,7 +124,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
)
run_manager = await callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
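Both `generate` and `agenerate` now accept a keyword-only `tags` argument and hand it, together with the model-level `self.tags`, to the callback manager. The merge itself happens inside `CallbackManager.configure`; a simplified sketch of the presumed semantics (not the actual implementation):

```python
from typing import List, Optional

def merge_tags(
    local_tags: Optional[List[str]],
    inheritable_tags: Optional[List[str]],
) -> List[str]:
    # Hypothetical illustration: the run trace ends up carrying the tags set
    # on the model instance as well as those passed to this specific call.
    return list(inheritable_tags or []) + list(local_tags or [])

assert merge_tags(["call-local"], ["inherited"]) == ["inherited", "call-local"]
```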

@@ -5,7 +5,16 @@ import asyncio
import functools
import logging
from datetime import datetime
from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union
from typing import (
Any,
Callable,
Coroutine,
Dict,
Iterator,
List,
Optional,
Union,
)
from langchainplus_sdk import LangChainPlusClient
from langchainplus_sdk.schemas import Example
@@ -104,6 +113,8 @@ async def _arun_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
langchain_tracer: Optional[LangChainTracer],
*,
tags: Optional[List[str]] = None,
) -> Union[LLMResult, ChatResult]:
callbacks: Optional[List[BaseCallbackHandler]] = (
[langchain_tracer] if langchain_tracer else None
@@ -111,21 +122,27 @@ async def _arun_llm(
if isinstance(llm, BaseLLM):
try:
llm_prompts = _get_prompts(inputs)
llm_output = await llm.agenerate(llm_prompts, callbacks=callbacks)
llm_output = await llm.agenerate(
llm_prompts, callbacks=callbacks, tags=tags
)
except InputFormatError:
llm_messages = _get_messages(inputs)
buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
llm_output = await llm.agenerate(buffer_strings, callbacks=callbacks)
llm_output = await llm.agenerate(
buffer_strings, callbacks=callbacks, tags=tags
)
elif isinstance(llm, BaseChatModel):
try:
messages = _get_messages(inputs)
llm_output = await llm.agenerate(messages, callbacks=callbacks)
llm_output = await llm.agenerate(messages, callbacks=callbacks, tags=tags)
except InputFormatError:
prompts = _get_prompts(inputs)
converted_messages: List[List[BaseMessage]] = [
[HumanMessage(content=prompt)] for prompt in prompts
]
llm_output = await llm.agenerate(converted_messages, callbacks=callbacks)
llm_output = await llm.agenerate(
converted_messages, callbacks=callbacks, tags=tags
)
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")
return llm_output
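`_arun_llm` now threads `tags` through every branch of its prompt-vs-messages fallback, so no input shape drops them. A hedged sketch of driving the helper directly (normally `arun_on_examples` does this for you; the input key is an assumption for illustration):

```python
import asyncio

from langchain.llms import OpenAI

llm = OpenAI()
inputs = {"prompt": "2 + 2 = ?"}  # hypothetical example-input shape
result = asyncio.run(_arun_llm(llm, inputs, None, tags=["smoke-test"]))
```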
@@ -136,6 +153,8 @@ async def _arun_llm_or_chain(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
langchain_tracer: Optional[LangChainTracer],
*,
tags: Optional[List[str]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain asynchronously."""
if langchain_tracer is not None:
@@ -150,11 +169,16 @@
try:
if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = await _arun_llm(
llm_or_chain_factory, example.inputs, langchain_tracer
llm_or_chain_factory,
example.inputs,
langchain_tracer,
tags=tags,
)
else:
chain = llm_or_chain_factory()
output = await chain.acall(example.inputs, callbacks=callbacks)
output = await chain.acall(
example.inputs, callbacks=callbacks, tags=tags
)
outputs.append(output)
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@@ -230,6 +254,7 @@ async def arun_on_examples(
num_repetitions: int = 1,
session_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Run the chain on examples and store traces to the specified session name.
@@ -245,6 +270,7 @@
intervals.
session_name: Session name to use when tracing runs.
verbose: Whether to print progress.
tags: Tags to add to the traces.
Returns:
A dictionary mapping example ids to the model outputs.
@@ -260,6 +286,7 @@
llm_or_chain_factory,
num_repetitions,
tracer,
tags=tags,
)
results[str(example.id)] = result
job_state["num_processed"] += 1
@@ -282,12 +309,14 @@ def run_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example."""
if isinstance(llm, BaseLLM):
try:
llm_prompts = _get_prompts(inputs)
llm_output = llm.generate(llm_prompts, callbacks=callbacks)
llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
except InputFormatError:
llm_messages = _get_messages(inputs)
buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
@@ -295,13 +324,15 @@
elif isinstance(llm, BaseChatModel):
try:
messages = _get_messages(inputs)
llm_output = llm.generate(messages, callbacks=callbacks)
llm_output = llm.generate(messages, callbacks=callbacks, tags=tags)
except InputFormatError:
prompts = _get_prompts(inputs)
converted_messages: List[List[BaseMessage]] = [
[HumanMessage(content=prompt)] for prompt in prompts
]
llm_output = llm.generate(converted_messages, callbacks=callbacks)
llm_output = llm.generate(
converted_messages, callbacks=callbacks, tags=tags
)
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")
return llm_output
@@ -312,6 +343,8 @@ def run_llm_or_chain(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
langchain_tracer: Optional[LangChainTracer] = None,
*,
tags: Optional[List[str]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain synchronously."""
if langchain_tracer is not None:
@@ -325,10 +358,12 @@
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = run_llm(llm_or_chain_factory, example.inputs, callbacks)
output: Any = run_llm(
llm_or_chain_factory, example.inputs, callbacks, tags=tags
)
else:
chain = llm_or_chain_factory()
output = chain(example.inputs, callbacks=callbacks)
output = chain(example.inputs, callbacks=callbacks, tags=tags)
outputs.append(output)
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@@ -345,6 +380,7 @@ def run_on_examples(
num_repetitions: int = 1,
session_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Run the chain on examples and store traces to the specified session name.
@@ -359,6 +395,7 @@
intervals.
session_name: Session name to use when tracing runs.
verbose: Whether to print progress.
tags: Tags to add to the run traces.
Returns:
A dictionary mapping example ids to the model outputs.
"""
@@ -370,6 +407,7 @@
llm_or_chain_factory,
num_repetitions,
langchain_tracer=tracer,
tags=tags,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
@@ -401,6 +439,7 @@ async def arun_on_dataset(
session_name: Optional[str] = None,
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Run the chain on a dataset and store traces to the specified session name.
@@ -420,6 +459,7 @@
verbose: Whether to print progress.
client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment.
tags: Tags to add to each run in the session.
Returns:
A dictionary containing the run's session name and the resulting model outputs.
@@ -436,6 +476,7 @@
num_repetitions=num_repetitions,
session_name=session_name,
verbose=verbose,
tags=tags,
)
return {
"session_name": session_name,
@@ -451,6 +492,7 @@ def run_on_dataset(
session_name: Optional[str] = None,
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified session name.
@@ -468,6 +510,7 @@
verbose: Whether to print progress.
client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment.
tags: Tags to add to each run in the session.
Returns:
A dictionary containing the run's session name and the resulting model outputs.
@@ -482,6 +525,7 @@
num_repetitions=num_repetitions,
session_name=session_name,
verbose=verbose,
tags=tags,
)
return {
"session_name": session_name,

@@ -369,6 +369,7 @@
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mclient\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[LangChainPlusClient]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[List[str]]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m\n",
"Run the chain on a dataset and store traces to the specified session name.\n",
@@ -388,6 +389,7 @@
" verbose: Whether to print progress.\n",
" client: Client to use to read the dataset. If not provided, a new\n",
" client will be created using the credentials in the environment.\n",
" tags: Tags to add to each run in the sesssion.\n",
"\n",
"Returns:\n",
" A dictionary containing the run's session name and the resulting model outputs.\n",
@@ -430,7 +432,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
"metadata": {
"tags": []
@@ -440,21 +442,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 4\r"
"Processed examples: 1\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n"
"Chain failed for example b36a82d3-4fb6-4bc4-87df-b7c355742b8e. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 5\r"
"Processed examples: 6\r"
]
}
],
@@ -465,6 +467,7 @@
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
" verbose=True,\n",
" client=client,\n",
" tags=[\"testing-notebook\", \"turbo\"], # Optional, adds a tag to the resulting chain runs\n",
")\n",
"\n",
"# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n",
@@ -486,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "136db492-d6ca-4215-96f9-439c23538241",
"metadata": {
"tags": []
@@ -501,7 +504,7 @@
"LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
]
},
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -534,7 +537,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "35db4025-9183-4e5f-ba14-0b1b380f49c7",
"metadata": {
"tags": []
@@ -565,7 +568,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9989f6507cd04ea7a09ea3c5723dc984",
"model_id": "5fce1ce42a8c4110b7d12443948ac697",
"version_major": 2,
"version_minor": 0
},
@@ -592,12 +595,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<a href=\"https://dev.langchain.plus\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
],
"text/plain": [
"LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client"
]

@@ -79,6 +79,8 @@ class BaseLLM(BaseLanguageModel, ABC):
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
class Config:
"""Configuration for this pydantic object."""
@@ -155,6 +157,8 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@@ -176,7 +180,7 @@ class BaseLLM(BaseLanguageModel, ABC):
) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
@@ -241,6 +245,8 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@@ -255,7 +261,7 @@ class BaseLLM(BaseLanguageModel, ABC):
) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"

@@ -1,7 +1,7 @@
"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from unittest import mock
import pytest
@@ -170,6 +170,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
tracer: Any,
tags: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
