Add support for tags (#5898)

Adds support for tags on callbacks and runs.

Tags can be set on a `Chain` via the new `tags` field, or passed per call through `__call__`, `acall`, `run`, and `arun`. `CallbackManager` and `AsyncCallbackManager` now track `tags` and `inheritable_tags`, forward them to every `on_llm_start`, `on_chat_model_start`, `on_chain_start`, and `on_tool_start` callback, and the tracers record them on `Run` objects. `get_child()` also accepts an optional tag so a parent run can label each sub-call; for example, `ConstitutionalChain` tags its sub-calls `original`, `critique`, and `revision`, and `SimpleSequentialChain` tags each step `step_1`, `step_2`, and so on.
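A short usage sketch mirroring the new test (assumes an OpenAI API key is configured; the prompt and tag names are taken from the test and are illustrative):

```python
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(temperature=0)
chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate.from_template("Q: {question} A:"),
    tags=["only-root"],  # the chain's own tags: attached to its runs, not inherited by children
)

# Tags passed at call time apply to this run and are inherited by child runs.
chain.run("what is the meaning of life", tags=["a-tag"])
```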

Branch: searx_updates
Nuno Campos committed 11 months ago (committed by GitHub)
parent 1281fdf0f2 · commit 11ab0be11a

@@ -124,6 +124,7 @@ class CallbackManagerMixin:
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running."""
@@ -135,6 +136,7 @@ class CallbackManagerMixin:
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
@@ -149,6 +151,7 @@ class CallbackManagerMixin:
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
@@ -160,6 +163,7 @@ class CallbackManagerMixin:
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
@@ -221,6 +225,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
@@ -232,6 +237,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
@@ -276,6 +282,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
@@ -307,6 +314,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
@@ -370,6 +378,9 @@ class BaseCallbackManager(CallbackManagerMixin):
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
@@ -377,6 +388,8 @@ class BaseCallbackManager(CallbackManagerMixin):
inheritable_handlers or []
)
self.parent_run_id: Optional[UUID] = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
@property
def is_async(self) -> bool:
@@ -406,3 +419,16 @@ class BaseCallbackManager(CallbackManagerMixin):
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
self.tags.extend(tags)
if inherit:
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
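A minimal sketch of the tag bookkeeping the new `add_tags`/`remove_tags` methods provide, using the concrete `CallbackManager` defined later in this diff (handlers are omitted for brevity):

```python
from langchain.callbacks.manager import CallbackManager

manager = CallbackManager(handlers=[])

# inherit=True (the default) records the tag in both lists, so child run
# managers created via get_child() will see it too.
manager.add_tags(["pipeline-v2"])
# inherit=False keeps the tag local to this manager.
manager.add_tags(["debug"], inherit=False)

assert manager.tags == ["pipeline-v2", "debug"]
assert manager.inheritable_tags == ["pipeline-v2"]

# remove_tags drops the tag from both lists.
manager.remove_tags(["pipeline-v2"])
assert manager.tags == ["debug"]
assert manager.inheritable_tags == []
```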

@@ -269,21 +269,32 @@ class BaseRunManager(RunManagerMixin):
def __init__(
self,
*,
run_id: UUID,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler],
parent_run_id: Optional[UUID] = None,
tags: List[str],
inheritable_tags: List[str],
) -> None:
"""Initialize run manager."""
self.run_id = run_id
self.handlers = handlers
self.inheritable_handlers = inheritable_handlers
self.tags = tags
self.inheritable_tags = inheritable_tags
self.parent_run_id = parent_run_id
@classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM:
"""Return a manager that doesn't perform any operations."""
return cls(uuid4(), [], [])
return cls(
run_id=uuid4(),
handlers=[],
inheritable_handlers=[],
tags=[],
inheritable_tags=[],
)
class RunManager(BaseRunManager):
@@ -425,10 +436,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def get_child(self) -> CallbackManager:
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
@@ -487,10 +501,13 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
def get_child(self) -> AsyncCallbackManager:
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
@@ -549,10 +566,13 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
"""Callback manager for tool run."""
def get_child(self) -> CallbackManager:
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
def on_tool_end(
@@ -591,10 +611,13 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
"""Async callback manager for tool run."""
def get_child(self) -> AsyncCallbackManager:
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
@@ -648,11 +671,17 @@ class CallbackManager(BaseCallbackManager):
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
def on_chat_model_start(
@@ -673,13 +702,19 @@ class CallbackManager(BaseCallbackManager):
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
# Re-use the LLM Run Manager since the outputs are treated
# the same for now
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
def on_chain_start(
@@ -701,11 +736,17 @@ class CallbackManager(BaseCallbackManager):
inputs,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return CallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
def on_tool_start(
@@ -728,11 +769,17 @@ class CallbackManager(BaseCallbackManager):
input_str,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return CallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
@classmethod
@@ -741,9 +788,18 @@ class CallbackManager(BaseCallbackManager):
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> CallbackManager:
"""Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
return _configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
)
class AsyncCallbackManager(BaseCallbackManager):
@@ -773,11 +829,17 @@ class AsyncCallbackManager(BaseCallbackManager):
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
async def on_chat_model_start(
@@ -798,11 +860,17 @@ class AsyncCallbackManager(BaseCallbackManager):
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
async def on_chain_start(
@@ -824,11 +892,17 @@ class AsyncCallbackManager(BaseCallbackManager):
inputs,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return AsyncCallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
async def on_tool_start(
@@ -851,11 +925,17 @@ class AsyncCallbackManager(BaseCallbackManager):
input_str,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
return AsyncCallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
@classmethod
@@ -864,9 +944,18 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> AsyncCallbackManager:
"""Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
return _configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
)
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
@@ -887,9 +976,11 @@ def _configure(
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> T:
"""Configure the callback manager."""
callback_manager = callback_manager_cls([])
callback_manager = callback_manager_cls(handlers=[])
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
inheritable_callbacks_ = inheritable_callbacks or []
@@ -902,6 +993,8 @@ def _configure(
handlers=inheritable_callbacks.handlers,
inheritable_handlers=inheritable_callbacks.inheritable_handlers,
parent_run_id=inheritable_callbacks.parent_run_id,
tags=inheritable_callbacks.tags,
inheritable_tags=inheritable_callbacks.inheritable_tags,
)
local_handlers_ = (
local_callbacks
@@ -910,6 +1003,9 @@ def _configure(
)
for handler in local_handlers_:
callback_manager.add_handler(handler, False)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
tracer = tracing_callback_var.get()
wandb_tracer = wandb_tracing_callback_var.get()
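With the `_configure` change above, tags can also be supplied when building a manager directly via `configure`; this is how `Chain.__call__` threads them through (the call-time `tags` argument becomes the inheritable tags, the chain's own `tags` field the local ones). A small sketch using a stock handler:

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.callbacks.manager import CallbackManager

manager = CallbackManager.configure(
    inheritable_callbacks=[StdOutCallbackHandler()],
    inheritable_tags=["a-tag"],   # forwarded to child run managers
    local_tags=["only-root"],     # attached to this manager's runs only
)
assert manager.tags == ["a-tag", "only-root"]
assert manager.inheritable_tags == ["a-tag"]
```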

@@ -85,6 +85,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
prompts: List[str],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
@@ -101,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order,
child_execution_order=execution_order,
run_type=RunTypeEnum.llm,
tags=tags or [],
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
@@ -145,6 +147,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
inputs: Dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
@@ -162,6 +165,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.chain,
tags=tags or [],
)
self._start_trace(chain_run)
self._on_chain_start(chain_run)
@@ -206,6 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
input_str: str,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
@@ -223,6 +228,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.tool,
tags=tags or [],
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)

@@ -59,6 +59,7 @@ class LangChainTracer(BaseTracer):
messages: List[List[BaseMessage]],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
@@ -75,6 +76,7 @@ class LangChainTracer(BaseTracer):
execution_order=execution_order,
child_execution_order=execution_order,
run_type=RunTypeEnum.llm,
tags=tags,
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)

@@ -94,6 +94,7 @@ class Run(BaseRunV2):
execution_order: int
child_execution_order: int
child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
@root_validator(pre=True)
def assign_name(cls, values: dict) -> dict:

@@ -36,6 +36,7 @@ class Chain(Serializable, ABC):
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
tags: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
@@ -111,6 +112,7 @@ class Chain(Serializable, ABC):
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -129,7 +131,7 @@ class Chain(Serializable, ABC):
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
@@ -159,6 +161,7 @@ class Chain(Serializable, ABC):
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -177,7 +180,7 @@ class Chain(Serializable, ABC):
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
@@ -244,7 +247,13 @@ class Chain(Serializable, ABC):
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@@ -255,10 +264,10 @@ class Chain(Serializable, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks)[self.output_keys[0]]
return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]]
if kwargs and not args:
return self(kwargs, callbacks=callbacks)[self.output_keys[0]]
return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]]
if not kwargs and not args:
raise ValueError(
@@ -271,7 +280,13 @@ class Chain(Serializable, ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
async def arun(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@@ -282,10 +297,14 @@ class Chain(Serializable, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]
return (await self.acall(args[0], callbacks=callbacks, tags=tags))[
self.output_keys[0]
]
if kwargs and not args:
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]
return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[
self.output_keys[0]
]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"

@@ -98,7 +98,7 @@ class ConstitutionalChain(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run(
**inputs,
callbacks=_run_manager.get_child(),
callbacks=_run_manager.get_child("original"),
)
initial_response = response
input_prompt = self.chain.prompt.format(**inputs)
@@ -116,7 +116,7 @@ class ConstitutionalChain(Chain):
input_prompt=input_prompt,
output_from_model=response,
critique_request=constitutional_principle.critique_request,
callbacks=_run_manager.get_child(),
callbacks=_run_manager.get_child("critique"),
)
critique = self._parse_critique(
output_string=raw_critique,
@@ -137,7 +137,7 @@ class ConstitutionalChain(Chain):
critique_request=constitutional_principle.critique_request,
critique=critique,
revision_request=constitutional_principle.revision_request,
callbacks=_run_manager.get_child(),
callbacks=_run_manager.get_child("revision"),
).strip()
response = revision
critiques_and_revisions.append((critique, revision))
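The tags passed to `get_child` above ("original", "critique", "revision") show the intended pattern for chain authors: label each sub-call so its run can be distinguished in traces. A hedged sketch of the same pattern in a hypothetical custom chain (class and field names are illustrative):

```python
from typing import Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class DraftThenRefineChain(Chain):
    """Hypothetical chain that tags each sub-chain call separately."""

    draft_chain: Chain
    refine_chain: Chain

    @property
    def input_keys(self) -> List[str]:
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        return ["result"]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Each child run inherits this run's inheritable tags plus a per-step
        # tag ("draft" / "refine") that is not propagated further down.
        draft = self.draft_chain.run(
            inputs["text"], callbacks=_run_manager.get_child("draft")
        )
        result = self.refine_chain.run(
            draft, callbacks=_run_manager.get_child("refine")
        )
        return {"result": result}
```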

@@ -283,7 +283,7 @@ class LLMChain(Chain):
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> Chain:
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)

@@ -174,7 +174,7 @@ class SimpleSequentialChain(Chain):
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = chain.run(_input, callbacks=_run_manager.get_child())
_input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
if self.strip_outputs:
_input = _input.strip()
_run_manager.on_text(
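With the change above, the i-th sub-chain of a `SimpleSequentialChain` now has its child run tagged `step_1`, `step_2`, and so on. A hedged usage sketch (prompts and names are illustrative; assumes an OpenAI key is configured):

```python
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(temperature=0)
outline = LLMChain(
    llm=llm, prompt=PromptTemplate.from_template("Outline a post about {topic}:")
)
write = LLMChain(
    llm=llm, prompt=PromptTemplate.from_template("Write the post from this outline:\n{outline}")
)

# The two child runs are tagged "step_1" and "step_2" respectively.
pipeline = SimpleSequentialChain(chains=[outline, write], tags=["blog-pipeline"])
pipeline.run("callback tags in LangChain")
```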

@@ -13,6 +13,8 @@ from langchain.callbacks.manager import (
tracing_v2_enabled,
)
from langchain.chains import LLMChain
from langchain.chains.constitutional_ai.base import ConstitutionalChain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
@@ -160,6 +162,25 @@ def test_tracing_v2_context_manager() -> None:
agent.run(questions[0]) # this should not be traced
def test_tracing_v2_chain_with_tags() -> None:
llm = OpenAI(temperature=0)
chain = ConstitutionalChain.from_llm(
llm,
chain=LLMChain.from_string(llm, "Q: {question} A:"),
tags=["only-root"],
constitutional_principles=[
ConstitutionalPrinciple(
critique_request="Tell if this answer is good.",
revision_request="Give a better answer.",
)
],
)
if "LANGCHAIN_TRACING_V2" in os.environ:
del os.environ["LANGCHAIN_TRACING_V2"]
with tracing_v2_enabled():
chain.run("what is the meaning of life", tags=["a-tag"])
def test_trace_as_group() -> None:
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(

@@ -86,7 +86,7 @@ def test_callback_manager() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2])
manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
@@ -143,7 +143,7 @@ async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeAsyncCallbackHandler()
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
manager = AsyncCallbackManager(handlers=[handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)
@@ -153,7 +153,7 @@ async def test_async_callback_manager_sync_handler() -> None:
handler1 = FakeCallbackHandler()
handler2 = FakeAsyncCallbackHandler()
handler3 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2, handler3])
manager = AsyncCallbackManager(handlers=[handler1, handler2, handler3])
await _test_callback_manager_async(manager, handler1, handler2, handler3)
@@ -165,11 +165,11 @@ def test_callback_manager_inheritance() -> None:
FakeCallbackHandler(),
)
callback_manager1 = CallbackManager([handler1, handler2])
callback_manager1 = CallbackManager(handlers=[handler1, handler2])
assert callback_manager1.handlers == [handler1, handler2]
assert callback_manager1.inheritable_handlers == []
callback_manager2 = CallbackManager([])
callback_manager2 = CallbackManager(handlers=[])
assert callback_manager2.handlers == []
assert callback_manager2.inheritable_handlers == []
@@ -229,7 +229,7 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
assert isinstance(configured_manager, CallbackManager)
async_local_callbacks = AsyncCallbackManager([handler3, handler4])
async_local_callbacks = AsyncCallbackManager(handlers=[handler3, handler4])
async_configured_manager = AsyncCallbackManager.configure(
inheritable_callbacks=inheritable_callbacks,
local_callbacks=async_local_callbacks,
