diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index 440336f2..8a06c1d8 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -39,6 +39,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
     callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
+    tags: Optional[List[str]] = Field(default=None, exclude=True)
+    """Tags to add to the run trace."""

     @root_validator()
     def raise_deprecation(cls, values: Dict) -> Dict:
@@ -65,6 +67,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -74,7 +78,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
         options = {"stop": stop}

         callback_manager = CallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks,
+            self.callbacks,
+            self.verbose,
+            tags,
+            self.tags,
         )
         run_manager = callback_manager.on_chat_model_start(
             dumpd(self), messages, invocation_params=params, options=options
@@ -106,6 +114,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -114,7 +124,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
         options = {"stop": stop}

         callback_manager = AsyncCallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks,
+            self.callbacks,
+            self.verbose,
+            tags,
+            self.tags,
         )
         run_manager = await callback_manager.on_chat_model_start(
             dumpd(self), messages, invocation_params=params, options=options
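The caller-facing effect of the chat-model changes, as a minimal sketch (not part of the patch): it assumes a configured OpenAI API key, and the tag values are arbitrary labels.

# Constructor-level tags are stored on the model (excluded from serialization,
# like `callbacks`) and apply to every run the model produces.
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(tags=["my-project"])

# The new keyword-only `tags` argument adds per-call tags; CallbackManager.configure()
# receives both lists, so the traced run carries "my-project" and "smoke-test".
result = chat.generate([[HumanMessage(content="Hello!")]], tags=["smoke-test"])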
diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py
index 9dd66481..bb295f39 100644
--- a/langchain/client/runner_utils.py
+++ b/langchain/client/runner_utils.py
@@ -5,7 +5,16 @@ import asyncio
 import functools
 import logging
 from datetime import datetime
-from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)

 from langchainplus_sdk import LangChainPlusClient
 from langchainplus_sdk.schemas import Example
@@ -104,6 +113,8 @@ async def _arun_llm(
     llm: BaseLanguageModel,
     inputs: Dict[str, Any],
     langchain_tracer: Optional[LangChainTracer],
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[LLMResult, ChatResult]:
     callbacks: Optional[List[BaseCallbackHandler]] = (
         [langchain_tracer] if langchain_tracer else None
@@ -111,21 +122,27 @@ async def _arun_llm(
     )
     if isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
-            llm_output = await llm.agenerate(llm_prompts, callbacks=callbacks)
+            llm_output = await llm.agenerate(
+                llm_prompts, callbacks=callbacks, tags=tags
+            )
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
-            llm_output = await llm.agenerate(buffer_strings, callbacks=callbacks)
+            llm_output = await llm.agenerate(
+                buffer_strings, callbacks=callbacks, tags=tags
+            )
     elif isinstance(llm, BaseChatModel):
         try:
             messages = _get_messages(inputs)
-            llm_output = await llm.agenerate(messages, callbacks=callbacks)
+            llm_output = await llm.agenerate(messages, callbacks=callbacks, tags=tags)
         except InputFormatError:
             prompts = _get_prompts(inputs)
             converted_messages: List[List[BaseMessage]] = [
                 [HumanMessage(content=prompt)] for prompt in prompts
             ]
-            llm_output = await llm.agenerate(converted_messages, callbacks=callbacks)
+            llm_output = await llm.agenerate(
+                converted_messages, callbacks=callbacks, tags=tags
+            )
     else:
         raise ValueError(f"Unsupported LLM type {type(llm)}")
     return llm_output
@@ -136,6 +153,8 @@ async def _arun_llm_or_chain(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     n_repetitions: int,
     langchain_tracer: Optional[LangChainTracer],
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """Run the chain asynchronously."""
     if langchain_tracer is not None:
@@ -150,11 +169,16 @@ async def _arun_llm_or_chain(
         try:
             if isinstance(llm_or_chain_factory, BaseLanguageModel):
                 output: Any = await _arun_llm(
-                    llm_or_chain_factory, example.inputs, langchain_tracer
+                    llm_or_chain_factory,
+                    example.inputs,
+                    langchain_tracer,
+                    tags=tags,
                 )
             else:
                 chain = llm_or_chain_factory()
-                output = await chain.acall(example.inputs, callbacks=callbacks)
+                output = await chain.acall(
+                    example.inputs, callbacks=callbacks, tags=tags
+                )
             outputs.append(output)
         except Exception as e:
             logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@@ -230,6 +254,7 @@ async def arun_on_examples(
     num_repetitions: int = 1,
     session_name: Optional[str] = None,
     verbose: bool = False,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """
     Run the chain on examples and store traces to the specified session name.
@@ -245,6 +270,7 @@ async def arun_on_examples(
             intervals.
         session_name: Session name to use when tracing runs.
         verbose: Whether to print progress.
+        tags: Tags to add to the traces.

     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -260,6 +286,7 @@ async def arun_on_examples(
                 llm_or_chain_factory,
                 num_repetitions,
                 tracer,
+                tags=tags,
             )
             results[str(example.id)] = result
             job_state["num_processed"] += 1
@@ -282,12 +309,14 @@ def run_llm(
     llm: BaseLanguageModel,
     inputs: Dict[str, Any],
     callbacks: Callbacks,
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """Run the language model on the example."""
     if isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
-            llm_output = llm.generate(llm_prompts, callbacks=callbacks)
+            llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
@@ -295,13 +324,15 @@ def run_llm(
     elif isinstance(llm, BaseChatModel):
         try:
             messages = _get_messages(inputs)
-            llm_output = llm.generate(messages, callbacks=callbacks)
+            llm_output = llm.generate(messages, callbacks=callbacks, tags=tags)
         except InputFormatError:
             prompts = _get_prompts(inputs)
             converted_messages: List[List[BaseMessage]] = [
                 [HumanMessage(content=prompt)] for prompt in prompts
             ]
-            llm_output = llm.generate(converted_messages, callbacks=callbacks)
+            llm_output = llm.generate(
+                converted_messages, callbacks=callbacks, tags=tags
+            )
     else:
         raise ValueError(f"Unsupported LLM type {type(llm)}")
     return llm_output
@@ -312,6 +343,8 @@ def run_llm_or_chain(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     n_repetitions: int,
     langchain_tracer: Optional[LangChainTracer] = None,
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """Run the chain synchronously."""
     if langchain_tracer is not None:
@@ -325,10 +358,12 @@ def run_llm_or_chain(
         for _ in range(n_repetitions):
             try:
                 if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                    output: Any = run_llm(llm_or_chain_factory, example.inputs, callbacks)
+                    output: Any = run_llm(
+                        llm_or_chain_factory, example.inputs, callbacks, tags=tags
+                    )
                 else:
                     chain = llm_or_chain_factory()
-                    output = chain(example.inputs, callbacks=callbacks)
+                    output = chain(example.inputs, callbacks=callbacks, tags=tags)
                 outputs.append(output)
             except Exception as e:
                 logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@@ -345,6 +380,7 @@ def run_on_examples(
     num_repetitions: int = 1,
     session_name: Optional[str] = None,
     verbose: bool = False,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run the chain on examples and store traces to the specified session name.

@@ -359,6 +395,7 @@ def run_on_examples(
             intervals.
         session_name: Session name to use when tracing runs.
         verbose: Whether to print progress.
+        tags: Tags to add to the run traces.

     Returns:
         A dictionary mapping example ids to the model outputs.
     """
@@ -370,6 +407,7 @@ def run_on_examples(
             llm_or_chain_factory,
             num_repetitions,
             langchain_tracer=tracer,
+            tags=tags,
         )
         if verbose:
             print(f"{i+1} processed", flush=True, end="\r")
@@ -401,6 +439,7 @@ async def arun_on_dataset(
     session_name: Optional[str] = None,
     verbose: bool = False,
     client: Optional[LangChainPlusClient] = None,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """
     Run the chain on a dataset and store traces to the specified session name.
@@ -420,6 +459,7 @@ async def arun_on_dataset(
         verbose: Whether to print progress.
         client: Client to use to read the dataset. If not provided, a new
             client will be created using the credentials in the environment.
+        tags: Tags to add to each run in the session.

     Returns:
         A dictionary containing the run's session name and the resulting model outputs.
@@ -436,6 +476,7 @@ async def arun_on_dataset(
         num_repetitions=num_repetitions,
         session_name=session_name,
         verbose=verbose,
+        tags=tags,
     )
     return {
         "session_name": session_name,
@@ -451,6 +492,7 @@ def run_on_dataset(
     session_name: Optional[str] = None,
     verbose: bool = False,
     client: Optional[LangChainPlusClient] = None,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run the chain on a dataset and store traces to the specified session name.

@@ -468,6 +510,7 @@ def run_on_dataset(
         verbose: Whether to print progress.
         client: Client to use to access the dataset. If None, a new client
             will be created using the credentials in the environment.
+        tags: Tags to add to each run in the session.

     Returns:
         A dictionary containing the run's session name and the resulting model outputs.
@@ -482,6 +525,7 @@ def run_on_dataset(
         num_repetitions=num_repetitions,
         session_name=session_name,
         verbose=verbose,
+        tags=tags,
     )
     return {
         "session_name": session_name,
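From the caller's side, the dataset runners now forward tags into every LLM call or chain invocation they make. A usage sketch, not part of the patch; the dataset name is hypothetical and the parameter names follow the hunks above.

from langchain.chat_models import ChatOpenAI
from langchain.client.runner_utils import run_on_dataset

results = run_on_dataset(
    dataset_name="my-eval-dataset",     # hypothetical dataset
    llm_or_chain_factory=ChatOpenAI(),  # a bare model also satisfies MODEL_OR_CHAIN_FACTORY
    num_repetitions=1,
    tags=["eval", "nightly"],           # attached to every traced run in the session
)
print(results["session_name"])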
diff --git a/langchain/experimental/client/tracing_datasets.ipynb b/langchain/experimental/client/tracing_datasets.ipynb
index d6eda66c..77436dba 100644
--- a/langchain/experimental/client/tracing_datasets.ipynb
+++ b/langchain/experimental/client/tracing_datasets.ipynb
@@ -369,6 +369,7 @@
     "\u001b[0;34m\u001b[0m    \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
     "\u001b[0;34m\u001b[0m    \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
     "\u001b[0;34m\u001b[0m    \u001b[0mclient\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[LangChainPlusClient]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
+    "\u001b[0;34m\u001b[0m    \u001b[0mtags\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[List[str]]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
     "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
     "\u001b[0;31mDocstring:\u001b[0m\n",
     "Run the chain on a dataset and store traces to the specified session name.\n",
@@ -388,6 +389,7 @@
     "    verbose: Whether to print progress.\n",
     "    client: Client to use to read the dataset. If not provided, a new\n",
     "        client will be created using the credentials in the environment.\n",
+    "    tags: Tags to add to each run in the session.\n",
     "\n",
     "Returns:\n",
     "    A dictionary containing the run's session name and the resulting model outputs.\n",
@@ -430,7 +432,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
    "metadata": {
     "tags": []
@@ -440,21 +442,21 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Processed examples: 4\r"
+     "Processed examples: 1\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n"
+     "Chain failed for example b36a82d3-4fb6-4bc4-87df-b7c355742b8e. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Processed examples: 5\r"
+     "Processed examples: 6\r"
     ]
    }
   ],
   "source": [
    "chain_results = await arun_on_dataset(\n",
    "    dataset_name=dataset_name,\n",
    "    llm_or_chain_factory=agent_factory,\n",
    "    concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
    "    verbose=True,\n",
    "    client=client,\n",
+   "    tags=[\"testing-notebook\", \"turbo\"], # Optional, adds a tag to the resulting chain runs\n",
    ")\n",
    "\n",
    "# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n",
@@ -486,7 +489,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "id": "136db492-d6ca-4215-96f9-439c23538241",
    "metadata": {
     "tags": []
@@ -501,7 +504,7 @@
      "LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
     ]
    },
-   "execution_count": 13,
+   "execution_count": 14,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -534,7 +537,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 15,
    "id": "35db4025-9183-4e5f-ba14-0b1b380f49c7",
    "metadata": {
     "tags": []
@@ -565,7 +568,7 @@
   {
    "data": {
     "application/vnd.jupyter.widget-view+json": {
-     "model_id": "9989f6507cd04ea7a09ea3c5723dc984",
+     "model_id": "5fce1ce42a8c4110b7d12443948ac697",
      "version_major": 2,
      "version_minor": 0
     },
@@ -592,12 +595,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
    "metadata": {
     "tags": []
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "LangChain+ Client"
+      ],
+      "text/plain": [
+       "LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "client"
    ]
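The notebook uses IPython's top-level await; from a plain script the async runner needs an explicit event loop. A sketch under the same assumptions (hypothetical dataset name, arbitrary tag values):

import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.client.runner_utils import arun_on_dataset

async def main() -> None:
    # Mirrors the notebook cell above, including the new tags argument.
    await arun_on_dataset(
        dataset_name="my-eval-dataset",
        llm_or_chain_factory=ChatOpenAI(),
        concurrency_level=5,
        tags=["testing-notebook", "turbo"],
    )

asyncio.run(main())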
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 9b065142..2647635d 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -79,6 +79,8 @@ class BaseLLM(BaseLanguageModel, ABC):
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
     callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
+    tags: Optional[List[str]] = Field(default=None, exclude=True)
+    """Tags to add to the run trace."""

     class Config:
         """Configuration for this pydantic object."""
@@ -155,6 +157,8 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -176,7 +180,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         callback_manager = CallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks, self.callbacks, self.verbose, tags, self.tags
         )
         new_arg_supported = inspect.signature(self._generate).parameters.get(
             "run_manager"
@@ -241,6 +245,8 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -255,7 +261,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         callback_manager = AsyncCallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks, self.callbacks, self.verbose, tags, self.tags
         )
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
             "run_manager"
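Completion-style models get the same two levels of tagging as chat models. A short sketch (not part of the patch), again assuming an OpenAI key and arbitrary tag names:

from langchain.llms import OpenAI

llm = OpenAI(tags=["my-project"])  # inherited by every run of this model
result = llm.generate(["Tell me a joke."], tags=["smoke-test"])  # merged in per call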
diff --git a/tests/unit_tests/client/test_runner_utils.py b/tests/unit_tests/client/test_runner_utils.py
index 84973d34..4487657e 100644
--- a/tests/unit_tests/client/test_runner_utils.py
+++ b/tests/unit_tests/client/test_runner_utils.py
@@ -1,7 +1,7 @@
 """Test the LangChain+ client."""
 import uuid
 from datetime import datetime
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 from unittest import mock

 import pytest
@@ -170,6 +170,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         llm_or_chain: Union[BaseLanguageModel, Chain],
         n_repetitions: int,
         tracer: Any,
+        tags: Optional[List[str]] = None,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)