Disable trace_on_chain_group auto-tracing (#12807)

Previously we treated trace_on_chain_group as a command to always start
tracing. This is unintuitive (makes the function do 2 things), and makes
it harder to toggle tracing
pull/12666/head^2
William FH 11 months ago committed by GitHub
parent 0da75b9ebd
commit 18005c6384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -124,9 +124,11 @@ def tracing_enabled(
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
try:
tracing_callback_var.set(cb)
yield session
finally:
tracing_callback_var.set(None)
@contextmanager
@ -191,9 +193,11 @@ def tracing_v2_enabled(
tags=tags,
client=client,
)
tracing_v2_callback_var.set(cb)
yield cb
tracing_v2_callback_var.set(None)
try:
tracing_v2_callback_var.set(cb)
yield cb
finally:
tracing_v2_callback_var.set(None)
@contextmanager
@ -214,6 +218,33 @@ def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None,
run_collector_var.set(None)
def _get_trace_callbacks(
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
) -> Callbacks:
if _tracing_v2_is_enabled():
project_name_ = project_name or _get_tracer_project()
tracer = tracing_v2_callback_var.get() or LangChainTracer(
project_name=project_name_,
example_id=example_id,
)
if callback_manager is None:
cb = cast(Callbacks, [tracer])
else:
if not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer, True)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
else:
cb = None
return cb
@contextmanager
def trace_as_chain_group(
group_name: str,
@ -241,6 +272,8 @@ def trace_as_chain_group(
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
Returns:
CallbackManagerForChainGroup: The callback manager for the chain group.
@ -253,16 +286,8 @@ def trace_as_chain_group(
res = llm.predict(llm_input, callbacks=manager)
manager.on_chain_end({"output": res})
""" # noqa: E501
cb = cast(
Callbacks,
[
LangChainTracer(
project_name=project_name,
example_id=example_id,
)
]
if callback_manager is None
else callback_manager,
cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager
)
cm = CallbackManager.configure(
inheritable_callbacks=cb,
@ -321,6 +346,8 @@ async def atrace_as_chain_group(
Returns:
AsyncCallbackManager: The async callback manager for the chain group.
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
Example:
.. code-block:: python
@ -330,16 +357,8 @@ async def atrace_as_chain_group(
res = await llm.apredict(llm_input, callbacks=manager)
await manager.on_chain_end({"output": res})
""" # noqa: E501
cb = cast(
Callbacks,
[
LangChainTracer(
project_name=project_name,
example_id=example_id,
)
]
if callback_manager is None
else callback_manager,
cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager
)
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
@ -1895,6 +1914,32 @@ def env_var_is_set(env_var: str) -> bool:
)
def _tracing_v2_is_enabled() -> bool:
return (
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracing_v2_callback_var.get() is not None
or get_run_tree_context() is not None
)
def _get_tracer_project() -> str:
run_tree = get_run_tree_context()
return getattr(
run_tree,
"session_name",
getattr(
# Note, if people are trying to nest @traceable functions and the
# tracing_v2_enabled context manager, this will likely mess up the
# tree structure.
tracing_v2_callback_var.get(),
"project",
os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
),
),
)
def _configure(
callback_manager_cls: Type[T],
inheritable_callbacks: Callbacks = None,
@ -1973,18 +2018,8 @@ def _configure(
)
tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracer_v2 is not None
or run_tree is not None
)
tracer_project = getattr(
run_tree,
"session_name",
os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
),
)
tracing_v2_enabled_ = _tracing_v2_is_enabled()
tracer_project = _get_tracer_project()
run_collector_ = run_collector_var.get()
debug = _get_debug()
if (

@ -1,5 +1,6 @@
"""Test CallbackManager."""
from typing import List, Tuple
from unittest.mock import patch
import pytest
@ -9,9 +10,10 @@ from langchain.callbacks.manager import (
CallbackManager,
get_openai_callback,
trace_as_chain_group,
tracing_v2_enabled,
)
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
from langchain.llms.openai import BaseOpenAI
from langchain.schema import AgentAction, AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
@ -303,70 +305,104 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
def test_callback_manager_configure_context_vars(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test callback manager configuration."""
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
with patch.object(LangChainTracer, "_update_run_single"):
with patch.object(LangChainTracer, "_persist_run_single"):
with trace_as_chain_group("test") as group_manager:
assert len(group_manager.handlers) == 1
tracer = group_manager.handlers[0]
assert isinstance(tracer, LangChainTracer)
with get_openai_callback() as cb:
# This is a new empty callback handler
assert cb.successful_requests == 0
assert cb.total_tokens == 0
# configure adds this openai cb but doesn't modify the group manager
mngr = CallbackManager.configure(group_manager)
assert mngr.handlers == [tracer, cb]
assert group_manager.handlers == [tracer]
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 2,
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
},
)
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
# The callback handler has been updated
assert cb.successful_requests == 1
assert cb.total_tokens == 3
assert cb.prompt_tokens == 2
assert cb.completion_tokens == 1
assert cb.total_cost > 0
with get_openai_callback() as cb:
# This is a new empty callback handler
assert cb.successful_requests == 0
assert cb.total_tokens == 0
# configure adds this openai cb but doesn't modify the group manager
mngr = CallbackManager.configure(group_manager)
assert mngr.handlers == [tracer, cb]
assert group_manager.handlers == [tracer]
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 2,
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
},
)
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
# The callback handler has been updated
assert cb.successful_requests == 1
assert cb.total_tokens == 3
assert cb.prompt_tokens == 2
assert cb.completion_tokens == 1
assert cb.total_cost > 0
wait_for_all_tracers()
assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore
def test_trace_as_chain_group_within_tracing_v2_context_manager(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test callback manager configuration."""
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
with trace_as_chain_group("test") as group_manager:
assert len(group_manager.handlers) == 1
tracer = group_manager.handlers[0]
assert isinstance(tracer, LangChainTracer)
with get_openai_callback() as cb:
# This is a new empty callback handler
assert cb.successful_requests == 0
assert cb.total_tokens == 0
# configure adds this openai cb but doesn't modify the group manager
mngr = CallbackManager.configure(group_manager)
assert mngr.handlers == [tracer, cb]
assert group_manager.handlers == [tracer]
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 2,
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
},
)
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
# The callback handler has been updated
assert cb.successful_requests == 1
assert cb.total_tokens == 3
assert cb.prompt_tokens == 2
assert cb.completion_tokens == 1
assert cb.total_cost > 0
with get_openai_callback() as cb:
# This is a new empty callback handler
assert cb.successful_requests == 0
assert cb.total_tokens == 0
# configure adds this openai cb but doesn't modify the group manager
mngr = CallbackManager.configure(group_manager)
assert mngr.handlers == [tracer, cb]
assert group_manager.handlers == [tracer]
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 2,
"completion_tokens": 1,
"total_tokens": 3,
},
"model_name": BaseOpenAI.__fields__["model_name"].default,
},
)
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
# The callback handler has been updated
assert cb.successful_requests == 1
assert cb.total_tokens == 3
assert cb.prompt_tokens == 2
assert cb.completion_tokens == 1
assert cb.total_cost > 0
with tracing_v2_enabled():
with patch.object(LangChainTracer, "_update_run_single"):
with patch.object(LangChainTracer, "_persist_run_single"):
with trace_as_chain_group("test") as group_manager:
assert len(group_manager.handlers) == 1
tracer = group_manager.handlers[0]
assert isinstance(tracer, LangChainTracer)
wait_for_all_tracers()
assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore
def test_trace_as_chain_group_tracing_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test callback manager configuration."""
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
with patch.object(LangChainTracer, "_update_run_single"):
with patch.object(LangChainTracer, "_persist_run_single"):
with trace_as_chain_group("test") as group_manager:
assert len(group_manager.handlers) == 0
assert LangChainTracer._persist_run_single.call_count == 0 # type: ignore

Loading…
Cancel
Save