diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 333a1a62..2bdf22ff 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -124,6 +124,7 @@ class CallbackManagerMixin: *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> Any: """Run when LLM starts running.""" @@ -135,6 +136,7 @@ class CallbackManagerMixin: *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running.""" @@ -149,6 +151,7 @@ class CallbackManagerMixin: *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> Any: """Run when chain starts running.""" @@ -160,6 +163,7 @@ class CallbackManagerMixin: *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> Any: """Run when tool starts running.""" @@ -221,6 +225,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when LLM starts running.""" @@ -232,6 +237,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running.""" @@ -276,6 +282,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when chain starts running.""" @@ -307,6 +314,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when tool starts running.""" @@ -370,6 +378,9 @@ class BaseCallbackManager(CallbackManagerMixin): handlers: List[BaseCallbackHandler], inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, parent_run_id: Optional[UUID] = None, + *, + tags: Optional[List[str]] = None, + inheritable_tags: Optional[List[str]] = None, ) -> None: """Initialize callback manager.""" self.handlers: List[BaseCallbackHandler] = handlers @@ -377,6 +388,8 @@ class BaseCallbackManager(CallbackManagerMixin): inheritable_handlers or [] ) self.parent_run_id: Optional[UUID] = parent_run_id + self.tags = tags or [] + self.inheritable_tags = inheritable_tags or [] @property def is_async(self) -> bool: @@ -406,3 +419,16 @@ class BaseCallbackManager(CallbackManagerMixin): def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Set handler as the only handler on the callback manager.""" self.set_handlers([handler], inherit=inherit) + + def add_tags(self, tags: List[str], inherit: bool = True) -> None: + for tag in tags: + if tag in self.tags: + self.remove_tags([tag]) + self.tags.extend(tags) + if inherit: + self.inheritable_tags.extend(tags) + + def remove_tags(self, tags: List[str]) -> None: + for tag in tags: + self.tags.remove(tag) + self.inheritable_tags.remove(tag) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 07600bf9..c407cc5c 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -269,21 +269,32 @@ class BaseRunManager(RunManagerMixin): def __init__( self, + *, run_id: UUID, handlers: List[BaseCallbackHandler], inheritable_handlers: List[BaseCallbackHandler], parent_run_id: Optional[UUID] = None, + tags: List[str], + inheritable_tags: List[str], ) -> None: """Initialize run manager.""" self.run_id = run_id self.handlers = handlers self.inheritable_handlers = inheritable_handlers + self.tags = tags + self.inheritable_tags = inheritable_tags self.parent_run_id = parent_run_id @classmethod def get_noop_manager(cls: Type[BRM]) -> BRM: """Return a manager that doesn't perform any operations.""" - return cls(uuid4(), [], []) + return cls( + run_id=uuid4(), + handlers=[], + inheritable_handlers=[], + tags=[], + inheritable_tags=[], + ) class RunManager(BaseRunManager): @@ -425,10 +436,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): class CallbackManagerForChainRun(RunManager, ChainManagerMixin): """Callback manager for chain run.""" - def get_child(self) -> CallbackManager: + def get_child(self, tag: Optional[str] = None) -> CallbackManager: """Get a child callback manager.""" - manager = CallbackManager([], parent_run_id=self.run_id) + manager = CallbackManager(handlers=[], parent_run_id=self.run_id) manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + if tag is not None: + manager.add_tags([tag], False) return manager def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: @@ -487,10 +501,13 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin): class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): """Async callback manager for chain run.""" - def get_child(self) -> AsyncCallbackManager: + def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: """Get a child callback manager.""" - manager = AsyncCallbackManager([], parent_run_id=self.run_id) + manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + if tag is not None: + manager.add_tags([tag], False) return manager async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: @@ -549,10 +566,13 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): class CallbackManagerForToolRun(RunManager, ToolManagerMixin): """Callback manager for tool run.""" - def get_child(self) -> CallbackManager: + def get_child(self, tag: Optional[str] = None) -> CallbackManager: """Get a child callback manager.""" - manager = CallbackManager([], parent_run_id=self.run_id) + manager = CallbackManager(handlers=[], parent_run_id=self.run_id) manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + if tag is not None: + manager.add_tags([tag], False) return manager def on_tool_end( @@ -591,10 +611,13 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin): class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): """Async callback manager for tool run.""" - def get_child(self) -> AsyncCallbackManager: + def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: """Get a child callback manager.""" - manager = AsyncCallbackManager([], parent_run_id=self.run_id) + manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + if tag is not None: + manager.add_tags([tag], False) return manager async def on_tool_end(self, output: str, **kwargs: Any) -> None: @@ -648,11 +671,17 @@ class CallbackManager(BaseCallbackManager): prompts, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return CallbackManagerForLLMRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) def on_chat_model_start( @@ -673,13 +702,19 @@ class CallbackManager(BaseCallbackManager): messages, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) # Re-use the LLM Run Manager since the outputs are treated # the same for now return CallbackManagerForLLMRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) def on_chain_start( @@ -701,11 +736,17 @@ class CallbackManager(BaseCallbackManager): inputs, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return CallbackManagerForChainRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) def on_tool_start( @@ -728,11 +769,17 @@ class CallbackManager(BaseCallbackManager): input_str, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return CallbackManagerForToolRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) @classmethod @@ -741,9 +788,18 @@ class CallbackManager(BaseCallbackManager): inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, ) -> CallbackManager: """Configure the callback manager.""" - return _configure(cls, inheritable_callbacks, local_callbacks, verbose) + return _configure( + cls, + inheritable_callbacks, + local_callbacks, + verbose, + inheritable_tags, + local_tags, + ) class AsyncCallbackManager(BaseCallbackManager): @@ -773,11 +829,17 @@ class AsyncCallbackManager(BaseCallbackManager): prompts, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return AsyncCallbackManagerForLLMRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) async def on_chat_model_start( @@ -798,11 +860,17 @@ class AsyncCallbackManager(BaseCallbackManager): messages, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return AsyncCallbackManagerForLLMRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) async def on_chain_start( @@ -824,11 +892,17 @@ class AsyncCallbackManager(BaseCallbackManager): inputs, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return AsyncCallbackManagerForChainRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) async def on_tool_start( @@ -851,11 +925,17 @@ class AsyncCallbackManager(BaseCallbackManager): input_str, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) return AsyncCallbackManagerForToolRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, ) @classmethod @@ -864,9 +944,18 @@ class AsyncCallbackManager(BaseCallbackManager): inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, ) -> AsyncCallbackManager: """Configure the callback manager.""" - return _configure(cls, inheritable_callbacks, local_callbacks, verbose) + return _configure( + cls, + inheritable_callbacks, + local_callbacks, + verbose, + inheritable_tags, + local_tags, + ) T = TypeVar("T", CallbackManager, AsyncCallbackManager) @@ -887,9 +976,11 @@ def _configure( inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, ) -> T: """Configure the callback manager.""" - callback_manager = callback_manager_cls([]) + callback_manager = callback_manager_cls(handlers=[]) if inheritable_callbacks or local_callbacks: if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: inheritable_callbacks_ = inheritable_callbacks or [] @@ -902,6 +993,8 @@ def _configure( handlers=inheritable_callbacks.handlers, inheritable_handlers=inheritable_callbacks.inheritable_handlers, parent_run_id=inheritable_callbacks.parent_run_id, + tags=inheritable_callbacks.tags, + inheritable_tags=inheritable_callbacks.inheritable_tags, ) local_handlers_ = ( local_callbacks @@ -910,6 +1003,9 @@ def _configure( ) for handler in local_handlers_: callback_manager.add_handler(handler, False) + if inheritable_tags or local_tags: + callback_manager.add_tags(inheritable_tags or []) + callback_manager.add_tags(local_tags or [], False) tracer = tracing_callback_var.get() wandb_tracer = wandb_tracing_callback_var.get() diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 93df3513..dd0c1183 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -85,6 +85,7 @@ class BaseTracer(BaseCallbackHandler, ABC): prompts: List[str], *, run_id: UUID, + tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: @@ -101,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC): execution_order=execution_order, child_execution_order=execution_order, run_type=RunTypeEnum.llm, + tags=tags or [], ) self._start_trace(llm_run) self._on_llm_start(llm_run) @@ -145,6 +147,7 @@ class BaseTracer(BaseCallbackHandler, ABC): inputs: Dict[str, Any], *, run_id: UUID, + tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: @@ -162,6 +165,7 @@ class BaseTracer(BaseCallbackHandler, ABC): child_execution_order=execution_order, child_runs=[], run_type=RunTypeEnum.chain, + tags=tags or [], ) self._start_trace(chain_run) self._on_chain_start(chain_run) @@ -206,6 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC): input_str: str, *, run_id: UUID, + tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: @@ -223,6 +228,7 @@ class BaseTracer(BaseCallbackHandler, ABC): child_execution_order=execution_order, child_runs=[], run_type=RunTypeEnum.tool, + tags=tags or [], ) self._start_trace(tool_run) self._on_tool_start(tool_run) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 2dabfcde..ecb3dbe5 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -59,6 +59,7 @@ class LangChainTracer(BaseTracer): messages: List[List[BaseMessage]], *, run_id: UUID, + tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: @@ -75,6 +76,7 @@ class LangChainTracer(BaseTracer): execution_order=execution_order, child_execution_order=execution_order, run_type=RunTypeEnum.llm, + tags=tags, ) self._start_trace(chat_model_run) self._on_chat_model_start(chat_model_run) diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index 1e264e7d..c0218c83 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -94,6 +94,7 @@ class Run(BaseRunV2): execution_order: int child_execution_order: int child_runs: List[Run] = Field(default_factory=list) + tags: Optional[List[str]] = Field(default_factory=list) @root_validator(pre=True) def assign_name(cls, values: dict) -> dict: diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 66354adc..dfa1b23a 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -36,6 +36,7 @@ class Chain(Serializable, ABC): verbose: bool = Field( default_factory=_get_verbosity ) # Whether to print the response text + tags: Optional[List[str]] = None class Config: """Configuration for this pydantic object.""" @@ -111,6 +112,7 @@ class Chain(Serializable, ABC): return_only_outputs: bool = False, callbacks: Callbacks = None, *, + tags: Optional[List[str]] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -129,7 +131,7 @@ class Chain(Serializable, ABC): """ inputs = self.prep_inputs(inputs) callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose + callbacks, self.callbacks, self.verbose, tags, self.tags ) new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") run_manager = callback_manager.on_chain_start( @@ -159,6 +161,7 @@ class Chain(Serializable, ABC): return_only_outputs: bool = False, callbacks: Callbacks = None, *, + tags: Optional[List[str]] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -177,7 +180,7 @@ class Chain(Serializable, ABC): """ inputs = self.prep_inputs(inputs) callback_manager = AsyncCallbackManager.configure( - callbacks, self.callbacks, self.verbose + callbacks, self.callbacks, self.verbose, tags, self.tags ) new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") run_manager = await callback_manager.on_chain_start( @@ -244,7 +247,13 @@ class Chain(Serializable, ABC): """Call the chain on all inputs in the list.""" return [self(inputs, callbacks=callbacks) for inputs in input_list] - def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: + def run( + self, + *args: Any, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( @@ -255,10 +264,10 @@ class Chain(Serializable, ABC): if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return self(args[0], callbacks=callbacks)[self.output_keys[0]] + return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]] if kwargs and not args: - return self(kwargs, callbacks=callbacks)[self.output_keys[0]] + return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]] if not kwargs and not args: raise ValueError( @@ -271,7 +280,13 @@ class Chain(Serializable, ABC): f" but not both. Got args: {args} and kwargs: {kwargs}." ) - async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: + async def arun( + self, + *args: Any, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( @@ -282,10 +297,14 @@ class Chain(Serializable, ABC): if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]] + return (await self.acall(args[0], callbacks=callbacks, tags=tags))[ + self.output_keys[0] + ] if kwargs and not args: - return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]] + return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[ + self.output_keys[0] + ] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index 525b2b3e..bf342120 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -98,7 +98,7 @@ class ConstitutionalChain(Chain): _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() response = self.chain.run( **inputs, - callbacks=_run_manager.get_child(), + callbacks=_run_manager.get_child("original"), ) initial_response = response input_prompt = self.chain.prompt.format(**inputs) @@ -116,7 +116,7 @@ class ConstitutionalChain(Chain): input_prompt=input_prompt, output_from_model=response, critique_request=constitutional_principle.critique_request, - callbacks=_run_manager.get_child(), + callbacks=_run_manager.get_child("critique"), ) critique = self._parse_critique( output_string=raw_critique, @@ -137,7 +137,7 @@ class ConstitutionalChain(Chain): critique_request=constitutional_principle.critique_request, critique=critique, revision_request=constitutional_principle.revision_request, - callbacks=_run_manager.get_child(), + callbacks=_run_manager.get_child("revision"), ).strip() response = revision critiques_and_revisions.append((critique, revision)) diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 4c743530..f8154c8c 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -283,7 +283,7 @@ class LLMChain(Chain): return "llm_chain" @classmethod - def from_string(cls, llm: BaseLanguageModel, template: str) -> Chain: + def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain: """Create LLMChain from LLM and template.""" prompt_template = PromptTemplate.from_template(template) return cls(llm=llm, prompt=prompt_template) diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index f94b5bc5..6877df2b 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -174,7 +174,7 @@ class SimpleSequentialChain(Chain): _input = inputs[self.input_key] color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): - _input = chain.run(_input, callbacks=_run_manager.get_child()) + _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}")) if self.strip_outputs: _input = _input.strip() _run_manager.on_text( diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index 80d18713..5415b050 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -13,6 +13,8 @@ from langchain.callbacks.manager import ( tracing_v2_enabled, ) from langchain.chains import LLMChain +from langchain.chains.constitutional_ai.base import ConstitutionalChain +from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI from langchain.prompts import PromptTemplate @@ -160,6 +162,25 @@ def test_tracing_v2_context_manager() -> None: agent.run(questions[0]) # this should not be traced +def test_tracing_v2_chain_with_tags() -> None: + llm = OpenAI(temperature=0) + chain = ConstitutionalChain.from_llm( + llm, + chain=LLMChain.from_string(llm, "Q: {question} A:"), + tags=["only-root"], + constitutional_principles=[ + ConstitutionalPrinciple( + critique_request="Tell if this answer is good.", + revision_request="Give a better answer.", + ) + ], + ) + if "LANGCHAIN_TRACING_V2" in os.environ: + del os.environ["LANGCHAIN_TRACING_V2"] + with tracing_v2_enabled(): + chain.run("what is the meaning of life", tags=["a-tag"]) + + def test_trace_as_group() -> None: llm = OpenAI(temperature=0.9) prompt = PromptTemplate( diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 2fb52165..24877b81 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -86,7 +86,7 @@ def test_callback_manager() -> None: """Test the CallbackManager.""" handler1 = FakeCallbackHandler() handler2 = FakeCallbackHandler() - manager = CallbackManager([handler1, handler2]) + manager = CallbackManager(handlers=[handler1, handler2]) _test_callback_manager(manager, handler1, handler2) @@ -143,7 +143,7 @@ async def test_async_callback_manager() -> None: """Test the AsyncCallbackManager.""" handler1 = FakeAsyncCallbackHandler() handler2 = FakeAsyncCallbackHandler() - manager = AsyncCallbackManager([handler1, handler2]) + manager = AsyncCallbackManager(handlers=[handler1, handler2]) await _test_callback_manager_async(manager, handler1, handler2) @@ -153,7 +153,7 @@ async def test_async_callback_manager_sync_handler() -> None: handler1 = FakeCallbackHandler() handler2 = FakeAsyncCallbackHandler() handler3 = FakeAsyncCallbackHandler() - manager = AsyncCallbackManager([handler1, handler2, handler3]) + manager = AsyncCallbackManager(handlers=[handler1, handler2, handler3]) await _test_callback_manager_async(manager, handler1, handler2, handler3) @@ -165,11 +165,11 @@ def test_callback_manager_inheritance() -> None: FakeCallbackHandler(), ) - callback_manager1 = CallbackManager([handler1, handler2]) + callback_manager1 = CallbackManager(handlers=[handler1, handler2]) assert callback_manager1.handlers == [handler1, handler2] assert callback_manager1.inheritable_handlers == [] - callback_manager2 = CallbackManager([]) + callback_manager2 = CallbackManager(handlers=[]) assert callback_manager2.handlers == [] assert callback_manager2.inheritable_handlers == [] @@ -229,7 +229,7 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None: assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler) assert isinstance(configured_manager, CallbackManager) - async_local_callbacks = AsyncCallbackManager([handler3, handler4]) + async_local_callbacks = AsyncCallbackManager(handlers=[handler3, handler4]) async_configured_manager = AsyncCallbackManager.configure( inheritable_callbacks=inheritable_callbacks, local_callbacks=async_local_callbacks,