@ -1,6 +1,7 @@
""" A Tracer Implementation that records activity to Weights & Biases. """
from __future__ import annotations
import json
from typing import (
TYPE_CHECKING ,
Any ,
@ -8,6 +9,7 @@ from typing import (
List ,
Optional ,
Sequence ,
Tuple ,
TypedDict ,
Union ,
)
@ -17,7 +19,7 @@ from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
if TYPE_CHECKING :
from wandb import Settings as WBSettings
from wandb . sdk . data_types import trace_tree
from wandb . sdk . data_types . trace_tree import Span
from wandb . sdk . lib . paths import StrPath
from wandb . wandb_run import Run as WBRun
@ -25,115 +27,350 @@ if TYPE_CHECKING:
# Module-wide switch: when truthy, span-conversion and model-serialization
# failures in this file are reported through wandb's ``termwarn``.
PRINT_WARNINGS = True
def _convert_lc_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Convert any LangChain Run into a W&B Trace Span.

    Dispatches on ``run.run_type`` to the LLM, chain or tool converter.

    :param trace_tree: The W&B trace_tree module used to build spans.
    :param run: The LangChain Run to convert.
    :return: The converted W&B Trace Span.
    """
    if run.run_type == RunTypeEnum.llm:
        return _convert_llm_run_to_wb_span(trace_tree, run)
    if run.run_type == RunTypeEnum.chain:
        return _convert_chain_run_to_wb_span(trace_tree, run)
    if run.run_type == RunTypeEnum.tool:
        return _convert_tool_run_to_wb_span(trace_tree, run)
    # Any other run type still yields a generic span instead of an implicit
    # ``None``, which would violate the declared return type.
    return _convert_run_to_wb_span(trace_tree, run)
def _serialize_inputs ( run_inputs : dict ) - > dict :
if " input_documents " in run_inputs :
docs = run_inputs [ " input_documents " ]
return { f " input_document_ { i } " : doc . json ( ) for i , doc in enumerate ( docs ) }
else :
return _convert_run_to_wb_span( trace_tree , run )
return run_inputs
def _convert_llm_run_to_wb_span ( trace_tree : Any , run : Run ) - > trace_tree . Span :
base_span = _convert_run_to_wb_span ( trace_tree , run )
class RunProcessor :
""" Handles the conversion of a LangChain Runs into a WBTraceTree. """
base_span . results = [
trace_tree . Result (
inputs = { " prompt " : prompt } ,
outputs = {
f " gen_ { g_i } " : gen [ " text " ]
for g_i , gen in enumerate ( run . outputs [ " generations " ] [ ndx ] )
}
if (
run . outputs is not None
and len ( run . outputs [ " generations " ] ) > ndx
and len ( run . outputs [ " generations " ] [ ndx ] ) > 0
def __init__ ( self , wandb_module : Any , trace_module : Any ) :
self . wandb = wandb_module
self . trace_tree = trace_module
def process_span ( self , run : Run ) - > Optional [ " Span " ] :
""" Converts a LangChain Run into a W&B Trace Span.
: param run : The LangChain Run to convert .
: return : The converted W & B Trace Span .
"""
try :
span = self . _convert_lc_run_to_wb_span ( run )
return span
except Exception as e :
if PRINT_WARNINGS :
self . wandb . termwarn (
f " Skipping trace saving - unable to safely convert LangChain Run "
f " into W&B Trace due to: { e } "
)
return None
def _convert_run_to_wb_span ( self , run : Run ) - > " Span " :
""" Base utility to create a span from a run.
: param run : The run to convert .
: return : The converted Span .
"""
attributes = { * * run . extra } if run . extra else { }
attributes [ " execution_order " ] = run . execution_order
return self . trace_tree . Span (
span_id = str ( run . id ) if run . id is not None else None ,
name = run . name ,
start_time_ms = int ( run . start_time . timestamp ( ) * 1000 ) ,
end_time_ms = int ( run . end_time . timestamp ( ) * 1000 ) ,
status_code = self . trace_tree . StatusCode . SUCCESS
if run . error is None
else self . trace_tree . StatusCode . ERROR ,
status_message = run . error ,
attributes = attributes ,
)
def _convert_llm_run_to_wb_span ( self , run : Run ) - > " Span " :
""" Converts a LangChain LLM Run into a W&B Trace Span.
: param run : The LangChain LLM Run to convert .
: return : The converted W & B Trace Span .
"""
base_span = self . _convert_run_to_wb_span ( run )
if base_span . attributes is None :
base_span . attributes = { }
base_span . attributes [ " llm_output " ] = run . outputs . get ( " llm_output " , { } )
base_span . results = [
self . trace_tree . Result (
inputs = { " prompt " : prompt } ,
outputs = {
f " gen_ { g_i } " : gen [ " text " ]
for g_i , gen in enumerate ( run . outputs [ " generations " ] [ ndx ] )
}
if (
run . outputs is not None
and len ( run . outputs [ " generations " ] ) > ndx
and len ( run . outputs [ " generations " ] [ ndx ] ) > 0
)
else None ,
)
for ndx , prompt in enumerate ( run . inputs [ " prompts " ] or [ ] )
]
base_span . span_kind = self . trace_tree . SpanKind . LLM
return base_span
def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
    """Convert a LangChain Chain Run into a W&B Trace Span.

    The span gets a single Result with the run's (serialized) inputs and its
    outputs, plus one child span per child run.

    :param run: The LangChain Chain Run to convert.
    :return: The converted W&B Trace Span.
    """
    base_span = self._convert_run_to_wb_span(run)
    base_span.results = [
        self.trace_tree.Result(
            inputs=_serialize_inputs(run.inputs), outputs=run.outputs
        )
    ]
    base_span.child_spans = [
        self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
    ]
    # The span kind is chosen from the run name alone.
    base_span.span_kind = (
        self.trace_tree.SpanKind.AGENT
        if " agent " in run.name.lower()
        else self.trace_tree.SpanKind.CHAIN
    )
    return base_span
def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
    """Turn a LangChain Tool Run into a W&B Trace Span.

    :param run: The LangChain Tool Run to convert.
    :return: The converted W&B Trace Span, tagged as a TOOL span.
    """
    span = self._convert_run_to_wb_span(run)

    # A tool run contributes exactly one result: its inputs and outputs.
    result = self.trace_tree.Result(
        inputs=_serialize_inputs(run.inputs), outputs=run.outputs
    )
    span.results = [result]

    children = []
    for child_run in run.child_runs:
        children.append(self._convert_lc_run_to_wb_span(child_run))
    span.child_spans = children

    span.span_kind = self.trace_tree.SpanKind.TOOL
    return span
def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
    """Convert any LangChain Run into a W&B Trace Span.

    Picks the LLM, chain or tool converter from ``run.run_type`` and falls
    back to the generic converter for every other run type.

    :param run: The LangChain Run to convert.
    :return: The converted W&B Trace Span.
    """
    kind = run.run_type
    if kind == RunTypeEnum.llm:
        return self._convert_llm_run_to_wb_span(run)
    if kind == RunTypeEnum.chain:
        return self._convert_chain_run_to_wb_span(run)
    if kind == RunTypeEnum.tool:
        return self._convert_tool_run_to_wb_span(run)
    return self._convert_run_to_wb_span(run)
def process_model ( self , run : Run ) - > Optional [ Dict [ str , Any ] ] :
""" Utility to process a run for wandb model_dict serialization.
: param run : The run to process .
: return : The convert model_dict to pass to WBTraceTree .
"""
try :
data = json . loads ( run . json ( ) )
processed = self . flatten_run ( data )
keep_keys = (
" id " ,
" name " ,
" serialized " ,
" inputs " ,
" outputs " ,
" parent_run_id " ,
" execution_order " ,
)
processed = self . truncate_run_iterative ( processed , keep_keys = keep_keys )
exact_keys , partial_keys = ( " lc " , " type " ) , ( " api_key " , )
processed = self . modify_serialized_iterative (
processed , exact_keys = exact_keys , partial_keys = partial_keys
)
output = self . build_tree ( processed )
return output
except Exception as e :
if PRINT_WARNINGS :
self . wandb . termwarn ( f " WARNING: Failed to serialize model: { e } " )
return None
def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a nested run dictionary into a parent-first list of runs.

    Each run's child list is popped off the run dictionary (so the input is
    modified in place) and the children follow their parent in the result.

    :param run: The base run to flatten.
    :return: The flattened list of runs.
    """

    def walk(nodes):
        """Yield every run in ``nodes``, each followed by its descendants."""
        if nodes is None:
            return
        for node in nodes:
            descendants = node.pop(" child_runs ", [])
            yield node
            yield from walk(descendants)

    return list(walk([run]))
def truncate_run_iterative(
    self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
) -> List[Dict[str, Any]]:
    """Reduce every run dictionary to the requested keys.

    :param runs: The list of runs to truncate.
    :param keep_keys: The keys to keep in each run.
    :return: A new list with one truncated copy per run.
    """
    return [
        {key: value for key, value in run.items() if key in keep_keys}
        for run in runs
    ]
def modify_serialized_iterative(
    self,
    runs: List[Dict[str, Any]],
    exact_keys: Tuple[str, ...] = (),
    partial_keys: Tuple[str, ...] = (),
) -> List[Dict[str, Any]]:
    """Rewrite the serialized field of each run for WBTraceTree.

    For every run this:
      * lifts the contents of nested kwargs dictionaries to their parent,
      * replaces nested id/name fields with a single kind field holding the
        last element of the id (or the name when there is no id),
      * drops keys equal to one of ``exact_keys`` or containing one of
        ``partial_keys``,
      * promotes the serialized field's entries to the top level of the run,
      * wraps the result in a one-entry dictionary keyed by the run's
        execution order and name.

    The run dictionaries passed in are modified in place along the way.

    :param runs: The list of runs to modify.
    :param exact_keys: Keys to remove wherever they occur.
    :param partial_keys: Substrings; any key containing one is removed.
    :return: The modified list of runs.
    """

    def remove_exact_and_partial_keys(obj: Any) -> Any:
        """Recursively drop exact and partial key matches from ``obj``.

        :param obj: A dictionary, list or leaf value.
        :return: The filtered structure; leaf values are returned as-is.
        """
        if isinstance(obj, dict):
            obj = {
                k: v
                for k, v in obj.items()
                if k not in exact_keys
                and not any(partial in k for partial in partial_keys)
            }
            for k, v in obj.items():
                obj[k] = remove_exact_and_partial_keys(v)
        elif isinstance(obj, list):
            obj = [remove_exact_and_partial_keys(x) for x in obj]
        return obj

    def handle_id_and_kwargs(obj: Any, root: bool = False) -> Any:
        """Recursively fold id/name into a kind field and flatten kwargs.

        :param obj: A dictionary, list or leaf value.
        :param root: True for the run dictionary itself, whose own id and
            name must be left untouched.
        :return: The rewritten structure; leaf values are returned as-is.
        """
        if isinstance(obj, dict):
            if (" id " in obj or " name " in obj) and not root:
                _kind = obj.get(" id ")
                if not _kind:
                    # No id: fall back to the name, wrapped so the last-
                    # element lookup below works for both cases.
                    _kind = [obj.get(" name ")]
                obj[" _kind "] = _kind[-1]
                obj.pop(" id ", None)
                obj.pop(" name ", None)
            if " kwargs " in obj:
                kwargs = obj.pop(" kwargs ")
                for k, v in kwargs.items():
                    obj[k] = v
            for k, v in obj.items():
                obj[k] = handle_id_and_kwargs(v)
        elif isinstance(obj, list):
            obj = [handle_id_and_kwargs(x) for x in obj]
        return obj

    def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the id/kwargs rewrite, then the key filter, to a run.

        :param serialized: The run dictionary to rewrite.
        :return: The rewritten dictionary.
        """
        serialized = handle_id_and_kwargs(serialized, root=True)
        serialized = remove_exact_and_partial_keys(serialized)
        return serialized

    def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
        """Transform one run dictionary into its WBTraceTree form.

        :param run: The run dictionary to transform.
        :return: A one-entry dictionary keyed by execution order and name.
        """
        transformed_dict = transform_serialized(run)
        # Promote everything under the serialized field to the top level.
        serialized = transformed_dict.pop(" serialized ")
        for k, v in serialized.items():
            transformed_dict[k] = v

        _kind = transformed_dict.get(" _kind ", None)
        name = transformed_dict.pop(" name ", None)
        exec_ord = transformed_dict.pop(" execution_order ", None)

        if not name:
            name = _kind

        output_dict = {
            f" { exec_ord } _ { name } ": transformed_dict,
        }
        return output_dict

    return list(map(transform_run, runs))
def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Turn a LangChain Chain Run into a W&B Trace Span.

    :param trace_tree: The W&B trace_tree module used to build spans.
    :param run: The LangChain Chain Run to convert.
    :return: The converted span, tagged AGENT or CHAIN from the serialized
        name of the run.
    """
    span = _convert_run_to_wb_span(trace_tree, run)

    result = trace_tree.Result(
        inputs=_serialize_inputs(run.inputs), outputs=run.outputs
    )
    span.results = [result]

    children = []
    for child_run in run.child_runs:
        children.append(_convert_lc_run_to_wb_span(trace_tree, child_run))
    span.child_spans = children

    if " agent " in run.serialized.get(" name ", " ").lower():
        span.span_kind = trace_tree.SpanKind.AGENT
    else:
        span.span_kind = trace_tree.SpanKind.CHAIN
    return span
def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
    """Turn a LangChain Tool Run into a W&B Trace Span.

    :param trace_tree: The W&B trace_tree module used to build spans.
    :param run: The LangChain Tool Run to convert.
    :return: The converted span, tagged as a TOOL span.
    """
    span = _convert_run_to_wb_span(trace_tree, run)

    result = trace_tree.Result(
        inputs=_serialize_inputs(run.inputs), outputs=run.outputs
    )
    span.results = [result]

    children = []
    for child_run in run.child_runs:
        children.append(_convert_lc_run_to_wb_span(trace_tree, child_run))
    span.child_spans = children

    span.span_kind = trace_tree.SpanKind.TOOL
    return span
def _convert_run_to_wb_span ( trace_tree : Any , run : Run ) - > trace_tree . Span :
attributes = { * * run . extra } if run . extra else { }
attributes [ " execution_order " ] = run . execution_order
return trace_tree . Span (
span_id = str ( run . id ) if run . id is not None else None ,
name = run . serialized . get ( " name " ) ,
start_time_ms = int ( run . start_time . timestamp ( ) * 1000 ) ,
end_time_ms = int ( run . end_time . timestamp ( ) * 1000 ) ,
status_code = trace_tree . StatusCode . SUCCESS
if run . error is None
else trace_tree . StatusCode . ERROR ,
status_message = run . error ,
attributes = attributes ,
)
def _replace_type_with_kind ( data : Any ) - > Any :
if isinstance ( data , dict ) :
# W&B TraceTree expects "_kind" instead of "_type" since `_type` is special
# in W&B.
if " _type " in data :
_type = data . pop ( " _type " )
data [ " _kind " ] = _type
return { k : _replace_type_with_kind ( v ) for k , v in data . items ( ) }
elif isinstance ( data , list ) :
return [ _replace_type_with_kind ( v ) for v in data ]
elif isinstance ( data , tuple ) :
return tuple ( _replace_type_with_kind ( v ) for v in data )
elif isinstance ( data , set ) :
return { _replace_type_with_kind ( v ) for v in data }
else :
return data
def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Nest a flat list of runs into a single tree.

    Every entry of ``runs`` maps a label to that run's data. Each run's id
    and parent id are popped from its data (modifying it in place), every
    child is then attached under its parent's data using the child's label,
    and the first run without a parent is returned.

    :param runs: The list of runs to build the tree from.
    :return: The nested dictionary representing the run tree.
    """
    nodes_by_id = {}
    parent_of = {}

    for entity in runs:
        for label, payload in entity.items():
            run_id = payload.pop(" id ", None)
            parent_run_id = payload.pop(" parent_run_id ", None)
            nodes_by_id[run_id] = {label: payload}
            if parent_run_id:
                parent_of[run_id] = parent_run_id

    for child_id, parent_id in parent_of.items():
        parent_node = nodes_by_id[parent_id]
        child_node = nodes_by_id[child_id]
        # Each node is a one-entry dictionary, so its first key is its label.
        child_label = next(iter(child_node))
        child_payload = child_node[child_label]
        parent_label = next(iter(parent_node))
        parent_node[parent_label][child_label] = child_payload

    return next(
        node for run_id, node in nodes_by_id.items() if run_id not in parent_of
    )
class WandbRunArgs ( TypedDict ) :
@ -201,12 +438,13 @@ class WandbTracer(BaseTracer):
except ImportError as e :
raise ImportError (
" Could not import wandb python package. "
" Please install it with `pip install wandb`."
" Please install it with `pip install -U wandb`."
) from e
self . _wandb = wandb
self . _trace_tree = trace_tree
self . _run_args = run_args
self . _ensure_run ( should_print_url = ( wandb . run is None ) )
self . run_processor = RunProcessor ( self . _wandb , self . _trace_tree )
def finish ( self ) - > None :
""" Waits for all asynchronous processes to finish and data to upload.
@ -219,23 +457,11 @@ class WandbTracer(BaseTracer):
""" Logs a LangChain Run to W*B as a W&B Trace. """
self . _ensure_run ( )
try :
root_span = _convert_lc_run_to_wb_span ( self . _trace_tree , run )
except Exception as e :
if PRINT_WARNINGS :
self . _wandb . termwarn (
f " Skipping trace saving - unable to safely convert LangChain Run "
f " into W&B Trace due to: { e } "
)
return
model_dict = None
root_span = self . run_processor . process_span ( run )
model_dict = self . run_processor . process_model ( run )
# TODO: Add something like this once we have a way to get the clean serialized
# parent dict from a run:
# serialized_parent = safely_get_span_producing_model(run)
# if serialized_parent is not None:
# model_dict = safely_convert_model_to_dict(serialized_parent)
if root_span is None :
return
model_trace = self . _trace_tree . WBTraceTree (
root_span = root_span ,