feat(integrations): Add WandbTracer (#4521)

# WandbTracer
This PR adds the `WandbTracer` and deprecates the existing
`WandbCallbackHandler`.

Added an example notebook under the docs section alongside the
`LangchainTracer`
Here's an example
[colab](https://colab.research.google.com/drive/1pY13ym8ENEZ8Fh7nA99ILk2GcdUQu0jR?usp=sharing)
with the same notebook and the
[trace](https://wandb.ai/parambharat/langchain-tracing/runs/8i45cst6)
generated from the colab run


Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
searx_updates
Bharat Ramanathan 12 months ago committed by GitHub
parent 373ad49157
commit 22603d19e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because one or more lines are too long

@ -10,7 +10,9 @@
"\n",
"Run in Colab: https://colab.research.google.com/drive/1DXH4beT4HFaRKy_Vm4PoxhXVDRf7Ym8L?usp=sharing\n",
"\n",
"View Report: https://wandb.ai/a-sh0ts/langchain_callback_demo/reports/Prompt-Engineering-LLMs-with-LangChain-and-W-B--VmlldzozNjk1NTUw#👋-how-to-build-a-callback-in-langchain-for-better-prompt-engineering"
"View Report: https://wandb.ai/a-sh0ts/langchain_callback_demo/reports/Prompt-Engineering-LLMs-with-LangChain-and-W-B--VmlldzozNjk1NTUw#👋-how-to-build-a-callback-in-langchain-for-better-prompt-engineering\n",
"\n",
"**Note**: _the `WandbCallbackHandler` is being deprecated in favour of the `WandbTracer`_ . In future please use the `WandbTracer` as it is more flexible and allows for more granular logging. To know more about the `WandbTracer` refer to the agent_with_wandb_tracing.ipynb notebook in docs or use the following [colab](https://colab.research.google.com/drive/1pY13ym8ENEZ8Fh7nA99ILk2GcdUQu0jR?usp=sharing)."
]
},
{
@ -107,7 +109,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mharrison-chase\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
"\u001B[34m\u001B[1mwandb\u001B[0m: Currently logged in as: \u001B[33mharrison-chase\u001B[0m. Use \u001B[1m`wandb login --relogin`\u001B[0m to force relogin\n"
]
},
{
@ -174,7 +176,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The wandb callback is currently in beta and is subject to change based on updates to `langchain`. Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
"\u001B[34m\u001B[1mwandb\u001B[0m: \u001B[33mWARNING\u001B[0m The wandb callback is currently in beta and is subject to change based on updates to `langchain`. Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
]
}
],
@ -521,20 +523,20 @@
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mDiCaprio had a steady girlfriend in Camila Morrone. He had been with the model turned actress for nearly five years, as they were first said to be dating at the end of 2017. And the now 26-year-old Morrone is no stranger to Hollywood.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.43 power.\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001B[0m\n",
"Observation: \u001B[36;1m\u001B[1;3mDiCaprio had a steady girlfriend in Camila Morrone. He had been with the model turned actress for nearly five years, as they were first said to be dating at the end of 2017. And the now 26-year-old Morrone is no stranger to Hollywood.\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3m I need to calculate her age raised to the 0.43 power.\n",
"Action: Calculator\n",
"Action Input: 26^0.43\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 4.059182145592686\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: Leo DiCaprio's girlfriend is Camila Morrone and her current age raised to the 0.43 power is 4.059182145592686.\u001b[0m\n",
"Action Input: 26^0.43\u001B[0m\n",
"Observation: \u001B[33;1m\u001B[1;3mAnswer: 4.059182145592686\n",
"\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3m I now know the final answer.\n",
"Final Answer: Leo DiCaprio's girlfriend is Camila Morrone and her current age raised to the 0.43 power is 4.059182145592686.\u001B[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{

@ -6,6 +6,7 @@ from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.manager import (
get_openai_callback,
tracing_enabled,
wandb_tracing_enabled,
)
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain.callbacks.openai_info import OpenAICallbackHandler
@ -26,4 +27,5 @@ __all__ = [
"AsyncIteratorCallbackHandler",
"get_openai_callback",
"tracing_enabled",
"wandb_tracing_enabled",
]

@ -25,6 +25,7 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema import (
AgentAction,
AgentFinish,
@ -44,6 +45,12 @@ tracing_callback_var: ContextVar[
] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
wandb_tracing_callback_var: ContextVar[
Optional[WandbTracer]
] = ContextVar( # noqa: E501
"tracing_wandb_callback", default=None
)
tracing_v2_callback_var: ContextVar[
Optional[LangChainTracer]
] = ContextVar( # noqa: E501
@ -76,6 +83,17 @@ def tracing_enabled(
tracing_callback_var.set(None)
@contextmanager
def wandb_tracing_enabled(
session_name: str = "default",
) -> Generator[None, None, None]:
"""Get WandbTracer in a context manager."""
cb = WandbTracer()
wandb_tracing_callback_var.set(cb)
yield None
wandb_tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
session_name: Optional[str] = None,
@ -831,12 +849,17 @@ def _configure(
callback_manager.add_handler(handler, False)
tracer = tracing_callback_var.get()
wandb_tracer = wandb_tracing_callback_var.get()
open_ai = openai_callback_var.get()
tracing_enabled_ = (
os.environ.get("LANGCHAIN_TRACING") is not None
or tracer is not None
or os.environ.get("LANGCHAIN_HANDLER") is not None
)
wandb_tracing_enabled_ = (
os.environ.get("LANGCHAIN_WANDB_TRACING") is not None
or wandb_tracer is not None
)
tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = (
@ -851,6 +874,7 @@ def _configure(
or debug
or tracing_enabled_
or tracing_v2_enabled_
or wandb_tracing_enabled_
or open_ai is not None
):
if verbose and not any(
@ -876,6 +900,14 @@ def _configure(
handler = LangChainTracerV1()
handler.load_session(tracer_session)
callback_manager.add_handler(handler, True)
if wandb_tracing_enabled_ and not any(
isinstance(handler, WandbTracer) for handler in callback_manager.handlers
):
if wandb_tracer:
callback_manager.add_handler(wandb_tracer, True)
else:
handler = WandbTracer()
callback_manager.add_handler(handler, True)
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers

@ -3,5 +3,11 @@
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
__all__ = ["LangChainTracer", "LangChainTracerV1", "ConsoleCallbackHandler"]
__all__ = [
"LangChainTracer",
"LangChainTracerV1",
"ConsoleCallbackHandler",
"WandbTracer",
]

@ -0,0 +1,265 @@
"""A Tracer Implementation that records activity to Weights & Biases."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
TypedDict,
Union,
)
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
if TYPE_CHECKING:
from wandb import Settings as WBSettings
from wandb.sdk.data_types import trace_tree
from wandb.sdk.lib.paths import StrPath
from wandb.wandb_run import Run as WBRun
PRINT_WARNINGS = True
def _convert_lc_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
if run.run_type == RunTypeEnum.llm:
return _convert_llm_run_to_wb_span(trace_tree, run)
elif run.run_type == RunTypeEnum.chain:
return _convert_chain_run_to_wb_span(trace_tree, run)
elif run.run_type == RunTypeEnum.tool:
return _convert_tool_run_to_wb_span(trace_tree, run)
else:
return _convert_run_to_wb_span(trace_tree, run)
def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
base_span = _convert_run_to_wb_span(trace_tree, run)
base_span.results = [
trace_tree.Result(
inputs={"prompt": prompt},
outputs={
f"gen_{g_i}": gen["text"]
for g_i, gen in enumerate(run.outputs["generations"][ndx])
}
if (
run.outputs is not None
and len(run.outputs["generations"]) > ndx
and len(run.outputs["generations"][ndx]) > 0
)
else None,
)
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
]
base_span.span_kind = trace_tree.SpanKind.LLM
return base_span
def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
base_span = _convert_run_to_wb_span(trace_tree, run)
base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)]
base_span.child_spans = [
_convert_lc_run_to_wb_span(trace_tree, child_run)
for child_run in run.child_runs
]
base_span.span_kind = (
trace_tree.SpanKind.AGENT
if "agent" in run.serialized.get("name", "").lower()
else trace_tree.SpanKind.CHAIN
)
return base_span
def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
base_span = _convert_run_to_wb_span(trace_tree, run)
base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)]
base_span.child_spans = [
_convert_lc_run_to_wb_span(trace_tree, child_run)
for child_run in run.child_runs
]
base_span.span_kind = trace_tree.SpanKind.TOOL
return base_span
def _convert_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
attributes = {**run.extra} if run.extra else {}
attributes["execution_order"] = run.execution_order
return trace_tree.Span(
span_id=str(run.id) if run.id is not None else None,
name=run.serialized.get("name"),
start_time_ms=int(run.start_time.timestamp() * 1000),
end_time_ms=int(run.end_time.timestamp() * 1000),
status_code=trace_tree.StatusCode.SUCCESS
if run.error is None
else trace_tree.StatusCode.ERROR,
status_message=run.error,
attributes=attributes,
)
def _replace_type_with_kind(data: Any) -> Any:
if isinstance(data, dict):
# W&B TraceTree expects "_kind" instead of "_type" since `_type` is special
# in W&B.
if "_type" in data:
_type = data.pop("_type")
data["_kind"] = _type
return {k: _replace_type_with_kind(v) for k, v in data.items()}
elif isinstance(data, list):
return [_replace_type_with_kind(v) for v in data]
elif isinstance(data, tuple):
return tuple(_replace_type_with_kind(v) for v in data)
elif isinstance(data, set):
return {_replace_type_with_kind(v) for v in data}
else:
return data
class WandbRunArgs(TypedDict):
job_type: Optional[str]
dir: Optional[StrPath]
config: Union[Dict, str, None]
project: Optional[str]
entity: Optional[str]
reinit: Optional[bool]
tags: Optional[Sequence]
group: Optional[str]
name: Optional[str]
notes: Optional[str]
magic: Optional[Union[dict, str, bool]]
config_exclude_keys: Optional[List[str]]
config_include_keys: Optional[List[str]]
anonymous: Optional[str]
mode: Optional[str]
allow_val_change: Optional[bool]
resume: Optional[Union[bool, str]]
force: Optional[bool]
tensorboard: Optional[bool]
sync_tensorboard: Optional[bool]
monitor_gym: Optional[bool]
save_code: Optional[bool]
id: Optional[str]
settings: Union[WBSettings, Dict[str, Any], None]
class WandbTracer(BaseTracer):
"""Callback Handler that logs to Weights and Biases.
This handler will log the model architecture and run traces to Weights and Biases.
This will ensure that all LangChain activity is logged to W&B.
"""
_run: Optional[WBRun] = None
_run_args: Optional[WandbRunArgs] = None
def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
"""Initializes the WandbTracer.
Parameters:
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
provided, `wandb.init()` will be called with no arguments. Please
refer to the `wandb.init` for more details.
To use W&B to monitor all LangChain activity, add this tracer like any other
LangChain callback:
```
from wandb.integration.langchain import WandbTracer
tracer = WandbTracer()
chain = LLMChain(llm, callbacks=[tracer])
# ...end of notebook / script:
tracer.finish()
```
"""
super().__init__(**kwargs)
try:
import wandb
from wandb.sdk.data_types import trace_tree
except ImportError as e:
raise ImportError(
"Could not import wandb python package."
"Please install it with `pip install wandb`."
) from e
self._wandb = wandb
self._trace_tree = trace_tree
self._run_args = run_args
self._ensure_run(should_print_url=(wandb.run is None))
def finish(self) -> None:
"""Waits for all asynchronous processes to finish and data to upload.
Proxy for `wandb.finish()`.
"""
self._wandb.finish()
def _log_trace_from_run(self, run: Run) -> None:
"""Logs a LangChain Run to W*B as a W&B Trace."""
self._ensure_run()
try:
root_span = _convert_lc_run_to_wb_span(self._trace_tree, run)
except Exception as e:
if PRINT_WARNINGS:
self._wandb.termwarn(
f"Skipping trace saving - unable to safely convert LangChain Run "
f"into W&B Trace due to: {e}"
)
return
model_dict = None
# TODO: Add something like this once we have a way to get the clean serialized
# parent dict from a run:
# serialized_parent = safely_get_span_producing_model(run)
# if serialized_parent is not None:
# model_dict = safely_convert_model_to_dict(serialized_parent)
model_trace = self._trace_tree.WBTraceTree(
root_span=root_span,
model_dict=model_dict,
)
if self._wandb.run is not None:
self._wandb.run.log({"langchain_trace": model_trace})
def _ensure_run(self, should_print_url: bool = False) -> None:
"""Ensures an active W&B run exists.
If not, will start a new run with the provided run_args.
"""
if self._wandb.run is None:
# Make a shallow copy of the run args, so we don't modify the original
run_args = self._run_args or {} # type: ignore
run_args: dict = {**run_args} # type: ignore
# Prefer to run in silent mode since W&B has a lot of output
# which can be undesirable when dealing with text-based models.
if "settings" not in run_args: # type: ignore
run_args["settings"] = {"silent": True} # type: ignore
# Start the run and add the stream table
self._wandb.init(**run_args)
if self._wandb.run is not None:
if should_print_url:
run_url = self._wandb.run.settings.run_url
self._wandb.termlog(
f"Streaming LangChain activity to W&B at {run_url}\n"
"`WandbTracer` is currently in beta.\n"
"Please report any issues to "
"https://github.com/wandb/wandb/issues with the tag "
"`langchain`."
)
self._wandb.run._label(repo="langchain")
def _persist_run(self, run: "Run") -> None:
"""Persist a run."""
self._log_trace_from_run(run)

@ -200,9 +200,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
notes=self.notes,
)
warning = (
"The wandb callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/wandb/wandb/issues with the tag `langchain`."
"DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
"of the `WandbTracer`. Please update your code to use the `WandbTracer` "
"instead."
)
wandb.termwarn(
warning,

@ -1,6 +1,7 @@
"""Generic utility functions."""
import contextlib
import datetime
import importlib
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
@ -115,3 +116,18 @@ def mock_now(dt_value): # type: ignore
yield datetime.datetime
finally:
datetime.datetime = real_datetime
def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not
installed."""
try:
module = importlib.import_module(module_name, package)
except ImportError:
raise ImportError(
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`."
)
return module

@ -3,7 +3,6 @@
The vector store can be persisted in json, bson or parquet format.
"""
import importlib
import json
import math
import os
@ -13,6 +12,7 @@ from uuid import uuid4
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import guard_import
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
@ -20,21 +20,6 @@ DEFAULT_K = 4 # Number of Documents to return.
DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not
installed."""
try:
module = importlib.import_module(module_name, package)
except ImportError:
raise ImportError(
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`."
)
return module
class BaseSerializer(ABC):
"""Abstract base class for saving and loading data."""

@ -0,0 +1,117 @@
"""Integration tests for the langchain tracer module."""
import asyncio
import os
import pytest
from aiohttp import ClientSession
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks.manager import wandb_tracing_enabled
from langchain.llms import OpenAI
questions = [
(
"Who won the US Open men's final in 2019? "
"What is his age raised to the 0.334 power?"
),
(
"Who is Olivia Wilde's boyfriend? "
"What is his current age raised to the 0.23 power?"
),
(
"Who won the most recent formula 1 grand prix? "
"What is their age raised to the 0.23 power?"
),
(
"Who won the US Open women's final in 2019? "
"What is her age raised to the 0.34 power?"
),
("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
]
def test_tracing_sequential() -> None:
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "langchain-tracing"
for q in questions[:3]:
llm = OpenAI(temperature=0)
tools = load_tools(
["llm-math", "serpapi"],
llm=llm,
)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run(q)
def test_tracing_session_env_var() -> None:
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
llm = OpenAI(temperature=0)
tools = load_tools(
["llm-math", "serpapi"],
llm=llm,
)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run(questions[0])
@pytest.mark.asyncio
async def test_tracing_concurrent() -> None:
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
aiosession = ClientSession()
llm = OpenAI(temperature=0)
async_tools = load_tools(
["llm-math", "serpapi"],
llm=llm,
aiosession=aiosession,
)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
tasks = [agent.arun(q) for q in questions[:3]]
await asyncio.gather(*tasks)
await aiosession.close()
def test_tracing_context_manager() -> None:
llm = OpenAI(temperature=0)
tools = load_tools(
["llm-math", "serpapi"],
llm=llm,
)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_WANDB_TRACING" in os.environ:
del os.environ["LANGCHAIN_WANDB_TRACING"]
with wandb_tracing_enabled():
agent.run(questions[0]) # this should be traced
agent.run(questions[0]) # this should not be traced
@pytest.mark.asyncio
async def test_tracing_context_manager_async() -> None:
llm = OpenAI(temperature=0)
async_tools = load_tools(
["llm-math", "serpapi"],
llm=llm,
)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_WANDB_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
# start a background task
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
with wandb_tracing_enabled():
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
await asyncio.gather(*tasks)
await task
Loading…
Cancel
Save