forked from Archives/langchain
feat(integrations): Add WandbTracer (#4521)
# WandbTracer This PR adds the `WandbTracer` and deprecates the existing `WandbCallbackHandler`. Added an example notebook under the docs section alongside the `LangchainTracer` Here's an example [colab](https://colab.research.google.com/drive/1pY13ym8ENEZ8Fh7nA99ILk2GcdUQu0jR?usp=sharing) with the same notebook and the [trace](https://wandb.ai/parambharat/langchain-tracing/runs/8i45cst6) generated from the colab run Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>searx_updates
parent
373ad49157
commit
22603d19e0
File diff suppressed because one or more lines are too long
@ -0,0 +1,265 @@
|
||||
"""A Tracer Implementation that records activity to Weights & Biases."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from wandb import Settings as WBSettings
|
||||
from wandb.sdk.data_types import trace_tree
|
||||
from wandb.sdk.lib.paths import StrPath
|
||||
from wandb.wandb_run import Run as WBRun
|
||||
|
||||
|
||||
PRINT_WARNINGS = True
|
||||
|
||||
|
||||
def _convert_lc_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Dispatch a LangChain Run to the converter matching its run type.

    Unknown run types fall back to the generic span converter.
    """
    converters = {
        RunTypeEnum.llm: _convert_llm_run_to_wb_span,
        RunTypeEnum.chain: _convert_chain_run_to_wb_span,
        RunTypeEnum.tool: _convert_tool_run_to_wb_span,
    }
    convert = converters.get(run.run_type, _convert_run_to_wb_span)
    return convert(trace_tree, run)
|
||||
|
||||
|
||||
def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Convert an LLM-type LangChain Run into a W&B Trace span.

    One Result is emitted per prompt; its outputs map generation index to
    generated text, or are None when no generation exists for that prompt.
    """
    base_span = _convert_run_to_wb_span(trace_tree, run)

    results = []
    for prompt_idx, prompt in enumerate(run.inputs["prompts"] or []):
        # Only build outputs when this prompt actually has generations.
        has_generations = (
            run.outputs is not None
            and len(run.outputs["generations"]) > prompt_idx
            and len(run.outputs["generations"][prompt_idx]) > 0
        )
        if has_generations:
            outputs = {
                f"gen_{gen_idx}": generation["text"]
                for gen_idx, generation in enumerate(
                    run.outputs["generations"][prompt_idx]
                )
            }
        else:
            outputs = None
        results.append(trace_tree.Result(inputs={"prompt": prompt}, outputs=outputs))

    base_span.results = results
    base_span.span_kind = trace_tree.SpanKind.LLM

    return base_span
|
||||
|
||||
|
||||
def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Convert a chain-type LangChain Run into a W&B Trace span.

    Child runs are converted recursively; a chain whose serialized name
    contains "agent" is labeled as an AGENT span.
    """
    base_span = _convert_run_to_wb_span(trace_tree, run)

    base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)]

    children = []
    for child in run.child_runs:
        children.append(_convert_lc_run_to_wb_span(trace_tree, child))
    base_span.child_spans = children

    if "agent" in run.serialized.get("name", "").lower():
        base_span.span_kind = trace_tree.SpanKind.AGENT
    else:
        base_span.span_kind = trace_tree.SpanKind.CHAIN

    return base_span
|
||||
|
||||
|
||||
def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Convert a tool-type LangChain Run into a W&B Trace span."""
    base_span = _convert_run_to_wb_span(trace_tree, run)

    base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)]
    base_span.child_spans = [
        _convert_lc_run_to_wb_span(trace_tree, child) for child in run.child_runs
    ]
    base_span.span_kind = trace_tree.SpanKind.TOOL

    return base_span
|
||||
|
||||
|
||||
def _convert_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
||||
attributes = {**run.extra} if run.extra else {}
|
||||
attributes["execution_order"] = run.execution_order
|
||||
|
||||
return trace_tree.Span(
|
||||
span_id=str(run.id) if run.id is not None else None,
|
||||
name=run.serialized.get("name"),
|
||||
start_time_ms=int(run.start_time.timestamp() * 1000),
|
||||
end_time_ms=int(run.end_time.timestamp() * 1000),
|
||||
status_code=trace_tree.StatusCode.SUCCESS
|
||||
if run.error is None
|
||||
else trace_tree.StatusCode.ERROR,
|
||||
status_message=run.error,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
|
||||
def _replace_type_with_kind(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
# W&B TraceTree expects "_kind" instead of "_type" since `_type` is special
|
||||
# in W&B.
|
||||
if "_type" in data:
|
||||
_type = data.pop("_type")
|
||||
data["_kind"] = _type
|
||||
return {k: _replace_type_with_kind(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [_replace_type_with_kind(v) for v in data]
|
||||
elif isinstance(data, tuple):
|
||||
return tuple(_replace_type_with_kind(v) for v in data)
|
||||
elif isinstance(data, set):
|
||||
return {_replace_type_with_kind(v) for v in data}
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
class WandbRunArgs(TypedDict):
    """Keyword arguments forwarded verbatim to `wandb.init()`.

    All keys are optional; see the `wandb.init` documentation for the
    semantics of each field. NOTE(review): this mirrors the `wandb.init`
    signature at the time of writing — keep in sync with the wandb SDK.
    """

    job_type: Optional[str]
    dir: Optional[StrPath]
    config: Union[Dict, str, None]
    project: Optional[str]
    entity: Optional[str]
    reinit: Optional[bool]
    tags: Optional[Sequence]
    group: Optional[str]
    name: Optional[str]
    notes: Optional[str]
    magic: Optional[Union[dict, str, bool]]
    config_exclude_keys: Optional[List[str]]
    config_include_keys: Optional[List[str]]
    anonymous: Optional[str]
    mode: Optional[str]
    allow_val_change: Optional[bool]
    resume: Optional[Union[bool, str]]
    force: Optional[bool]
    tensorboard: Optional[bool]
    sync_tensorboard: Optional[bool]
    monitor_gym: Optional[bool]
    save_code: Optional[bool]
    id: Optional[str]
    settings: Union[WBSettings, Dict[str, Any], None]
|
||||
|
||||
|
||||
class WandbTracer(BaseTracer):
    """Callback Handler that logs to Weights and Biases.

    This handler will log the model architecture and run traces to Weights and Biases.
    This will ensure that all LangChain activity is logged to W&B.
    """

    # Class-level defaults; the active wandb run is managed through the
    # wandb module itself (`self._wandb.run`), these mirror its state.
    _run: Optional[WBRun] = None
    _run_args: Optional[WandbRunArgs] = None

    def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
        """Initializes the WandbTracer.

        Parameters:
            run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
                provided, `wandb.init()` will be called with no arguments. Please
                refer to the `wandb.init` for more details.

        Raises:
            ImportError: if the `wandb` package is not installed.

        To use W&B to monitor all LangChain activity, add this tracer like any other
        LangChain callback:
        ```
        from wandb.integration.langchain import WandbTracer

        tracer = WandbTracer()
        chain = LLMChain(llm, callbacks=[tracer])
        # ...end of notebook / script:
        tracer.finish()
        ```
        """
        super().__init__(**kwargs)
        try:
            import wandb
            from wandb.sdk.data_types import trace_tree
        except ImportError as e:
            # FIX: the original message concatenated to "package.Please".
            raise ImportError(
                "Could not import wandb python package. "
                "Please install it with `pip install wandb`."
            ) from e
        self._wandb = wandb
        self._trace_tree = trace_tree
        self._run_args = run_args
        # Only print the run URL when this tracer is the one starting the run.
        self._ensure_run(should_print_url=(wandb.run is None))

    def finish(self) -> None:
        """Waits for all asynchronous processes to finish and data to upload.

        Proxy for `wandb.finish()`.
        """
        self._wandb.finish()

    def _log_trace_from_run(self, run: Run) -> None:
        """Logs a LangChain Run to W&B as a W&B Trace."""
        self._ensure_run()

        try:
            root_span = _convert_lc_run_to_wb_span(self._trace_tree, run)
        except Exception as e:
            # Best-effort logging: never let a conversion failure break the chain.
            if PRINT_WARNINGS:
                self._wandb.termwarn(
                    f"Skipping trace saving - unable to safely convert LangChain Run "
                    f"into W&B Trace due to: {e}"
                )
            return

        model_dict = None

        # TODO: Add something like this once we have a way to get the clean serialized
        # parent dict from a run:
        # serialized_parent = safely_get_span_producing_model(run)
        # if serialized_parent is not None:
        #     model_dict = safely_convert_model_to_dict(serialized_parent)

        model_trace = self._trace_tree.WBTraceTree(
            root_span=root_span,
            model_dict=model_dict,
        )
        if self._wandb.run is not None:
            self._wandb.run.log({"langchain_trace": model_trace})

    def _ensure_run(self, should_print_url: bool = False) -> None:
        """Ensures an active W&B run exists.

        If not, will start a new run with the provided run_args.
        """
        if self._wandb.run is None:
            # Make a shallow copy of the run args, so we don't modify the original.
            run_args: Dict[str, Any] = {**(self._run_args or {})}

            # Prefer to run in silent mode since W&B has a lot of output
            # which can be undesirable when dealing with text-based models.
            run_args.setdefault("settings", {"silent": True})

            # Start the run and add the stream table
            self._wandb.init(**run_args)
            if self._wandb.run is not None:
                if should_print_url:
                    run_url = self._wandb.run.settings.run_url
                    self._wandb.termlog(
                        f"Streaming LangChain activity to W&B at {run_url}\n"
                        "`WandbTracer` is currently in beta.\n"
                        "Please report any issues to "
                        "https://github.com/wandb/wandb/issues with the tag "
                        "`langchain`."
                    )

                self._wandb.run._label(repo="langchain")

    def _persist_run(self, run: "Run") -> None:
        """Persist a run."""
        self._log_trace_from_run(run)
|
@ -0,0 +1,117 @@
|
||||
"""Integration tests for the langchain tracer module."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.callbacks.manager import wandb_tracing_enabled
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
# Search-plus-math prompts shared by every test below; each one requires both
# the serpapi and llm-math tools so the agent produces a multi-step trace.
questions = [
    (
        "Who won the US Open men's final in 2019? "
        "What is his age raised to the 0.334 power?"
    ),
    (
        "Who is Olivia Wilde's boyfriend? "
        "What is his current age raised to the 0.23 power?"
    ),
    (
        "Who won the most recent formula 1 grand prix? "
        "What is their age raised to the 0.23 power?"
    ),
    (
        "Who won the US Open women's final in 2019? "
        "What is her age raised to the 0.34 power?"
    ),
    ("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
]
|
||||
|
||||
|
||||
def test_tracing_sequential() -> None:
    """Run three agent queries back to back with W&B tracing enabled."""
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
    os.environ["WANDB_PROJECT"] = "langchain-tracing"

    for question in questions[:3]:
        llm = OpenAI(temperature=0)
        tools = load_tools(["llm-math", "serpapi"], llm=llm)
        agent = initialize_agent(
            tools,
            llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
        )
        agent.run(question)
|
||||
|
||||
|
||||
def test_tracing_session_env_var() -> None:
    """A single traced agent run driven purely by the env-var switch."""
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"

    llm = OpenAI(temperature=0)
    tools = load_tools(["llm-math", "serpapi"], llm=llm)
    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    agent.run(questions[0])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tracing_concurrent() -> None:
    """Trace three agent queries running concurrently on one HTTP session.

    Fix over the original: the ClientSession is now managed with
    ``async with`` so it is closed even when an agent call raises; the
    previous version leaked the session on any failure before
    ``aiosession.close()``.
    """
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
    async with ClientSession() as aiosession:
        llm = OpenAI(temperature=0)
        async_tools = load_tools(
            ["llm-math", "serpapi"],
            llm=llm,
            aiosession=aiosession,
        )
        agent = initialize_agent(
            async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
        )
        tasks = [agent.arun(q) for q in questions[:3]]
        await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
def test_tracing_context_manager() -> None:
    """Tracing is active only inside `wandb_tracing_enabled()`."""
    llm = OpenAI(temperature=0)
    tools = load_tools(["llm-math", "serpapi"], llm=llm)
    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    # Clear the env-var switch so only the context manager controls tracing.
    os.environ.pop("LANGCHAIN_WANDB_TRACING", None)
    with wandb_tracing_enabled():
        agent.run(questions[0])  # this should be traced

    agent.run(questions[0])  # this should not be traced
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tracing_context_manager_async() -> None:
    """Async variant: only the runs inside the context manager are traced.

    Fix over the original: it checked for ``LANGCHAIN_WANDB_TRACING`` in the
    environment but then deleted ``LANGCHAIN_TRACING`` — raising KeyError
    when that key was absent and leaving the wandb flag set, so the
    "untraced" background task was in fact traced.
    """
    llm = OpenAI(temperature=0)
    async_tools = load_tools(
        ["llm-math", "serpapi"],
        llm=llm,
    )
    agent = initialize_agent(
        async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )
    os.environ.pop("LANGCHAIN_WANDB_TRACING", None)

    # start a background task
    task = asyncio.create_task(agent.arun(questions[0]))  # this should not be traced
    with wandb_tracing_enabled():
        tasks = [agent.arun(q) for q in questions[1:4]]  # these should be traced
        await asyncio.gather(*tasks)

    await task
|
Loading…
Reference in New Issue